From 40f353c92ec132222e849abc071fdb85d4768915 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Thu, 2 Feb 2023 18:22:21 +0000 Subject: [PATCH 01/72] Support for no_std Support for no_std Further no_std compat --- examples/onoff_light/src/lib.rs | 2 +- examples/onoff_light/src/main.rs | 63 +- examples/speaker/src/lib.rs | 2 +- examples/speaker/src/main.rs | 89 ++- examples/speaker/src/speaker.rs | 69 ++ matter/Cargo.toml | 7 +- matter/src/acl.rs | 441 +++++----- matter/src/cert/asn1_writer.rs | 7 +- matter/src/cert/mod.rs | 20 +- matter/src/cert/printer.rs | 2 +- matter/src/codec/base38.rs | 10 +- matter/src/core.rs | 145 ++-- matter/src/crypto/crypto_dummy.rs | 99 ++- matter/src/crypto/crypto_esp_mbedtls.rs | 18 +- matter/src/crypto/crypto_mbedtls.rs | 23 +- matter/src/crypto/crypto_openssl.rs | 15 +- matter/src/crypto/mod.rs | 27 +- .../data_model/cluster_basic_information.rs | 175 ++-- matter/src/data_model/cluster_on_off.rs | 185 +++-- matter/src/data_model/cluster_template.rs | 54 +- matter/src/data_model/core.rs | 199 +++++ matter/src/data_model/core/mod.rs | 394 --------- matter/src/data_model/core/read.rs | 319 -------- matter/src/data_model/core/subscribe.rs | 142 ---- matter/src/data_model/device_types.rs | 56 +- matter/src/data_model/mod.rs | 3 +- matter/src/data_model/objects/attribute.rs | 116 +-- matter/src/data_model/objects/cluster.rs | 515 ++++++------ matter/src/data_model/objects/dataver.rs | 55 ++ matter/src/data_model/objects/encoder.rs | 476 ++++++++++- matter/src/data_model/objects/endpoint.rs | 137 ++-- matter/src/data_model/objects/handler.rs | 350 ++++++++ matter/src/data_model/objects/mod.rs | 24 +- matter/src/data_model/objects/node.rs | 554 +++++++------ matter/src/data_model/root_endpoint.rs | 108 +++ .../src/data_model/sdm/admin_commissioning.rs | 225 +++--- matter/src/data_model/sdm/failsafe.rs | 37 +- .../data_model/sdm/general_commissioning.rs | 347 ++++---- matter/src/data_model/sdm/noc.rs | 753 ++++++++++-------- matter/src/data_model/sdm/nw_commissioning.rs | 53 +- .../data_model/system_model/access_control.rs | 400 ++++++---- .../src/data_model/system_model/descriptor.rs | 249 +++--- matter/src/error.rs | 54 +- matter/src/fabric.rs | 559 ++++++++----- matter/src/group_keys.rs | 5 +- matter/src/interaction_model/command.rs | 88 -- matter/src/interaction_model/core.rs | 570 +++++++++---- matter/src/interaction_model/messages.rs | 117 +-- matter/src/interaction_model/mod.rs | 68 -- matter/src/interaction_model/read.rs | 42 - matter/src/interaction_model/write.rs | 58 -- matter/src/lib.rs | 8 +- matter/src/mdns.rs | 202 +++-- matter/src/pairing/code.rs | 50 +- matter/src/pairing/qr.rs | 41 +- matter/src/persist.rs | 229 ++++++ matter/src/secure_channel/case.rs | 150 ++-- matter/src/secure_channel/common.rs | 28 +- matter/src/secure_channel/core.rs | 61 +- matter/src/secure_channel/crypto.rs | 49 +- matter/src/secure_channel/crypto_dummy.rs | 73 ++ .../src/secure_channel/crypto_esp_mbedtls.rs | 47 +- matter/src/secure_channel/crypto_mbedtls.rs | 77 +- matter/src/secure_channel/crypto_openssl.rs | 75 +- matter/src/secure_channel/mod.rs | 11 +- matter/src/secure_channel/pake.rs | 171 ++-- matter/src/secure_channel/spake2p.rs | 46 +- matter/src/secure_channel/status_report.rs | 2 +- matter/src/sys/mod.rs | 9 +- matter/src/sys/posix.rs | 96 --- matter/src/tlv/parser.rs | 40 +- matter/src/tlv/traits.rs | 38 +- matter/src/tlv/writer.rs | 26 +- matter/src/transport/exchange.rs | 302 ++++--- matter/src/transport/mgr.rs | 299 ++++--- matter/src/transport/mod.rs | 2 +- matter/src/transport/mrp.rs | 22 +- matter/src/transport/network.rs | 17 +- matter/src/transport/packet.rs | 51 +- matter/src/transport/plain_hdr.rs | 9 +- matter/src/transport/proto_ctx.rs | 43 + matter/src/transport/proto_demux.rs | 95 --- matter/src/transport/proto_hdr.rs | 30 +- matter/src/transport/session.rs | 261 +++--- matter/src/transport/udp.rs | 22 +- matter/src/utils/epoch.rs | 14 + matter/src/utils/mod.rs | 2 + matter/src/utils/parsebuf.rs | 47 +- matter/src/utils/rand.rs | 3 + matter/src/utils/writebuf.rs | 56 +- matter/tests/common/attributes.rs | 3 +- matter/tests/common/commands.rs | 4 +- matter/tests/common/echo_cluster.rs | 300 ++++--- matter/tests/common/handlers.rs | 317 ++++++++ matter/tests/common/im_engine.rs | 214 +++-- matter/tests/common/mod.rs | 1 + matter/tests/data_model/acl_and_dataver.rs | 308 +++---- matter/tests/data_model/attribute_lists.rs | 40 +- matter/tests/data_model/attributes.rs | 176 ++-- matter/tests/data_model/commands.rs | 52 +- matter/tests/data_model/timed_requests.rs | 199 +---- matter/tests/data_model_tests.rs | 2 +- matter/tests/interaction_model.rs | 145 ++-- 103 files changed, 7177 insertions(+), 5914 deletions(-) create mode 100644 examples/speaker/src/speaker.rs create mode 100644 matter/src/data_model/core.rs delete mode 100644 matter/src/data_model/core/mod.rs delete mode 100644 matter/src/data_model/core/read.rs delete mode 100644 matter/src/data_model/core/subscribe.rs create mode 100644 matter/src/data_model/objects/dataver.rs create mode 100644 matter/src/data_model/objects/handler.rs create mode 100644 matter/src/data_model/root_endpoint.rs delete mode 100644 matter/src/interaction_model/command.rs delete mode 100644 matter/src/interaction_model/read.rs delete mode 100644 matter/src/interaction_model/write.rs create mode 100644 matter/src/persist.rs create mode 100644 matter/src/secure_channel/crypto_dummy.rs delete mode 100644 matter/src/sys/posix.rs create mode 100644 matter/src/transport/proto_ctx.rs delete mode 100644 matter/src/transport/proto_demux.rs create mode 100644 matter/src/utils/epoch.rs create mode 100644 matter/src/utils/rand.rs create mode 100644 matter/tests/common/handlers.rs diff --git a/examples/onoff_light/src/lib.rs b/examples/onoff_light/src/lib.rs index 43ca1b11..16264d04 100644 --- a/examples/onoff_light/src/lib.rs +++ b/examples/onoff_light/src/lib.rs @@ -15,4 +15,4 @@ * limitations under the License. */ -pub mod dev_att; +// TODO pub mod dev_att; diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 1eb5d638..b2bc4484 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -15,40 +15,41 @@ * limitations under the License. */ -mod dev_att; -use matter::core::{self, CommissioningData}; -use matter::data_model::cluster_basic_information::BasicInfoConfig; -use matter::data_model::device_types::device_type_add_on_off_light; -use matter::secure_channel::spake2p::VerifierData; +// TODO +// mod dev_att; +// use matter::core::{self, CommissioningData}; +// use matter::data_model::cluster_basic_information::BasicInfoConfig; +// use matter::data_model::device_types::device_type_add_on_off_light; +// use matter::secure_channel::spake2p::VerifierData; fn main() { - env_logger::init(); - let comm_data = CommissioningData { - // TODO: Hard-coded for now - verifier: VerifierData::new_with_pw(123456), - discriminator: 250, - }; + // env_logger::init(); + // let comm_data = CommissioningData { + // // TODO: Hard-coded for now + // verifier: VerifierData::new_with_pw(123456), + // discriminator: 250, + // }; - // vid/pid should match those in the DAC - let dev_info = BasicInfoConfig { - vid: 0xFFF1, - pid: 0x8000, - hw_ver: 2, - sw_ver: 1, - sw_ver_str: "1".to_string(), - serial_no: "aabbccdd".to_string(), - device_name: "OnOff Light".to_string(), - }; - let dev_att = Box::new(dev_att::HardCodedDevAtt::new()); + // // vid/pid should match those in the DAC + // let dev_info = BasicInfoConfig { + // vid: 0xFFF1, + // pid: 0x8000, + // hw_ver: 2, + // sw_ver: 1, + // sw_ver_str: "1".to_string(), + // serial_no: "aabbccdd".to_string(), + // device_name: "OnOff Light".to_string(), + // }; + // let dev_att = Box::new(dev_att::HardCodedDevAtt::new()); - let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap(); - let dm = matter.get_data_model(); - { - let mut node = dm.node.write().unwrap(); - let endpoint = device_type_add_on_off_light(&mut node).unwrap(); - println!("Added OnOff Light Device type at endpoint id: {}", endpoint); - println!("Data Model now is: {}", node); - } + // let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap(); + // let dm = matter.get_data_model(); + // { + // let mut node = dm.node.write().unwrap(); + // let endpoint = device_type_add_on_off_light(&mut node).unwrap(); + // println!("Added OnOff Light Device type at endpoint id: {}", endpoint); + // println!("Data Model now is: {}", node); + // } - matter.start_daemon().unwrap(); + // matter.start_daemon().unwrap(); } diff --git a/examples/speaker/src/lib.rs b/examples/speaker/src/lib.rs index 43ca1b11..16264d04 100644 --- a/examples/speaker/src/lib.rs +++ b/examples/speaker/src/lib.rs @@ -15,4 +15,4 @@ * limitations under the License. */ -pub mod dev_att; +// TODO pub mod dev_att; diff --git a/examples/speaker/src/main.rs b/examples/speaker/src/main.rs index de2a6051..f3b3f7db 100644 --- a/examples/speaker/src/main.rs +++ b/examples/speaker/src/main.rs @@ -15,55 +15,56 @@ * limitations under the License. */ -mod dev_att; -use matter::core::{self, CommissioningData}; -use matter::data_model::cluster_basic_information::BasicInfoConfig; -use matter::data_model::cluster_media_playback::{Commands, MediaPlaybackCluster}; -use matter::data_model::device_types::DEV_TYPE_ON_SMART_SPEAKER; -use matter::secure_channel::spake2p::VerifierData; +// TODO +// mod dev_att; +// use matter::core::{self, CommissioningData}; +// use matter::data_model::cluster_basic_information::BasicInfoConfig; +// use matter::data_model::cluster_media_playback::{Commands, MediaPlaybackCluster}; +// use matter::data_model::device_types::DEV_TYPE_ON_SMART_SPEAKER; +// use matter::secure_channel::spake2p::VerifierData; fn main() { - env_logger::init(); - let comm_data = CommissioningData { - // TODO: Hard-coded for now - verifier: VerifierData::new_with_pw(123456), - discriminator: 250, - }; + // env_logger::init(); + // let comm_data = CommissioningData { + // // TODO: Hard-coded for now + // verifier: VerifierData::new_with_pw(123456), + // discriminator: 250, + // }; - // vid/pid should match those in the DAC - let dev_info = BasicInfoConfig { - vid: 0xFFF1, - pid: 0x8002, - hw_ver: 2, - sw_ver: 1, - sw_ver_str: "1".to_string(), - serial_no: "aabbccdd".to_string(), - device_name: "Smart Speaker".to_string(), - }; - let dev_att = Box::new(dev_att::HardCodedDevAtt::new()); + // // vid/pid should match those in the DAC + // let dev_info = BasicInfoConfig { + // vid: 0xFFF1, + // pid: 0x8002, + // hw_ver: 2, + // sw_ver: 1, + // sw_ver_str: "1".to_string(), + // serial_no: "aabbccdd".to_string(), + // device_name: "Smart Speaker".to_string(), + // }; + // let dev_att = Box::new(dev_att::HardCodedDevAtt::new()); - let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap(); - let dm = matter.get_data_model(); - { - let mut node = dm.node.write().unwrap(); + // let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap(); + // let dm = matter.get_data_model(); + // { + // let mut node = dm.node.write().unwrap(); - let endpoint_audio = node.add_endpoint(DEV_TYPE_ON_SMART_SPEAKER).unwrap(); - let mut media_playback_cluster = MediaPlaybackCluster::new().unwrap(); + // let endpoint_audio = node.add_endpoint(DEV_TYPE_ON_SMART_SPEAKER).unwrap(); + // let mut media_playback_cluster = MediaPlaybackCluster::new().unwrap(); - // Add some callbacks - let play_callback = Box::new(|| log::info!("Comamnd [Play] handled with callback.")); - let pause_callback = Box::new(|| log::info!("Comamnd [Pause] handled with callback.")); - let stop_callback = Box::new(|| log::info!("Comamnd [Stop] handled with callback.")); - let start_over_callback = - Box::new(|| log::info!("Comamnd [StartOver] handled with callback.")); - media_playback_cluster.add_callback(Commands::Play, play_callback); - media_playback_cluster.add_callback(Commands::Pause, pause_callback); - media_playback_cluster.add_callback(Commands::Stop, stop_callback); - media_playback_cluster.add_callback(Commands::StartOver, start_over_callback); + // // Add some callbacks + // let play_callback = Box::new(|| log::info!("Comamnd [Play] handled with callback.")); + // let pause_callback = Box::new(|| log::info!("Comamnd [Pause] handled with callback.")); + // let stop_callback = Box::new(|| log::info!("Comamnd [Stop] handled with callback.")); + // let start_over_callback = + // Box::new(|| log::info!("Comamnd [StartOver] handled with callback.")); + // media_playback_cluster.add_callback(Commands::Play, play_callback); + // media_playback_cluster.add_callback(Commands::Pause, pause_callback); + // media_playback_cluster.add_callback(Commands::Stop, stop_callback); + // media_playback_cluster.add_callback(Commands::StartOver, start_over_callback); - node.add_cluster(endpoint_audio, media_playback_cluster) - .unwrap(); - println!("Added Speaker type at endpoint id: {}", endpoint_audio) - } - matter.start_daemon().unwrap(); + // node.add_cluster(endpoint_audio, media_playback_cluster) + // .unwrap(); + // println!("Added Speaker type at endpoint id: {}", endpoint_audio) + // } + // matter.start_daemon().unwrap(); } diff --git a/examples/speaker/src/speaker.rs b/examples/speaker/src/speaker.rs new file mode 100644 index 00000000..de2a6051 --- /dev/null +++ b/examples/speaker/src/speaker.rs @@ -0,0 +1,69 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +mod dev_att; +use matter::core::{self, CommissioningData}; +use matter::data_model::cluster_basic_information::BasicInfoConfig; +use matter::data_model::cluster_media_playback::{Commands, MediaPlaybackCluster}; +use matter::data_model::device_types::DEV_TYPE_ON_SMART_SPEAKER; +use matter::secure_channel::spake2p::VerifierData; + +fn main() { + env_logger::init(); + let comm_data = CommissioningData { + // TODO: Hard-coded for now + verifier: VerifierData::new_with_pw(123456), + discriminator: 250, + }; + + // vid/pid should match those in the DAC + let dev_info = BasicInfoConfig { + vid: 0xFFF1, + pid: 0x8002, + hw_ver: 2, + sw_ver: 1, + sw_ver_str: "1".to_string(), + serial_no: "aabbccdd".to_string(), + device_name: "Smart Speaker".to_string(), + }; + let dev_att = Box::new(dev_att::HardCodedDevAtt::new()); + + let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap(); + let dm = matter.get_data_model(); + { + let mut node = dm.node.write().unwrap(); + + let endpoint_audio = node.add_endpoint(DEV_TYPE_ON_SMART_SPEAKER).unwrap(); + let mut media_playback_cluster = MediaPlaybackCluster::new().unwrap(); + + // Add some callbacks + let play_callback = Box::new(|| log::info!("Comamnd [Play] handled with callback.")); + let pause_callback = Box::new(|| log::info!("Comamnd [Pause] handled with callback.")); + let stop_callback = Box::new(|| log::info!("Comamnd [Stop] handled with callback.")); + let start_over_callback = + Box::new(|| log::info!("Comamnd [StartOver] handled with callback.")); + media_playback_cluster.add_callback(Commands::Play, play_callback); + media_playback_cluster.add_callback(Commands::Pause, pause_callback); + media_playback_cluster.add_callback(Commands::Stop, stop_callback); + media_playback_cluster.add_callback(Commands::StartOver, start_over_callback); + + node.add_cluster(endpoint_audio, media_playback_cluster) + .unwrap(); + println!("Added Speaker type at endpoint id: {}", endpoint_audio) + } + matter.start_daemon().unwrap(); +} diff --git a/matter/Cargo.toml b/matter/Cargo.toml index ab965ba5..0987e10d 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -15,7 +15,9 @@ name = "matter" path = "src/lib.rs" [features] -default = ["crypto_mbedtls"] +default = ["std", "crypto_mbedtls"] +std = [] +nightly = [] crypto_openssl = ["openssl", "foreign-types", "hmac", "sha2"] crypto_mbedtls = ["mbedtls"] crypto_esp_mbedtls = ["esp-idf-sys"] @@ -34,7 +36,7 @@ num-traits = "0.2.15" log = { version = "0.4.17", features = ["max_level_debug", "release_max_level_debug"] } env_logger = { version = "0.10.0", default-features = false, features = [] } rand = "0.8.5" -esp-idf-sys = { version = "0.32", features = ["binstart"], optional = true } +esp-idf-sys = { version = "0.32", optional = true } subtle = "2.4.1" colored = "2.0.0" smol = "1.3.0" @@ -42,6 +44,7 @@ owning_ref = "0.4.1" safemem = "0.3.3" chrono = { version = "0.4.23", default-features = false, features = ["clock", "std"] } async-channel = "1.8" +strum = { version = "0.24", features = ["derive"], no-default-feature = true } # crypto openssl = { git = "https://github.com/sfackler/rust-openssl", optional = true } diff --git a/matter/src/acl.rs b/matter/src/acl.rs index 708ddee6..8f965e13 100644 --- a/matter/src/acl.rs +++ b/matter/src/acl.rs @@ -15,19 +15,16 @@ * limitations under the License. */ -use std::{ - fmt::Display, - sync::{Arc, Mutex, MutexGuard, RwLock}, -}; +use core::{cell::RefCell, fmt::Display}; use crate::{ data_model::objects::{Access, ClusterId, EndptId, Privilege}, error::Error, fabric, interaction_model::messages::GenericPath, - sys::Psm, + persist::Psm, tlv::{FromTLV, TLVElement, TLVList, TLVWriter, TagType, ToTLV}, - transport::session::MAX_CAT_IDS_PER_NOC, + transport::session::{Session, SessionMode, MAX_CAT_IDS_PER_NOC}, utils::writebuf::WriteBuf, }; use log::error; @@ -160,7 +157,7 @@ impl Display for AccessorSubjects { } /// The Accessor Object -pub struct Accessor { +pub struct Accessor<'a> { /// The fabric index of the accessor pub fab_idx: u8, /// Accessor's subject: could be node-id, NoC CAT, group id @@ -168,15 +165,37 @@ pub struct Accessor { /// The Authmode of this session auth_mode: AuthMode, // TODO: Is this the right place for this though, or should we just use a global-acl-handle-get - acl_mgr: Arc, + acl_mgr: &'a RefCell, } -impl Accessor { - pub fn new( +impl<'a> Accessor<'a> { + pub fn for_session(session: &Session, acl_mgr: &'a RefCell) -> Self { + match session.get_session_mode() { + SessionMode::Case(c) => { + let mut subject = + AccessorSubjects::new(session.get_peer_node_id().unwrap_or_default()); + for i in c.cat_ids { + if i != 0 { + let _ = subject.add_catid(i); + } + } + Accessor::new(c.fab_idx, subject, AuthMode::Case, &acl_mgr) + } + SessionMode::Pase => { + Accessor::new(0, AccessorSubjects::new(1), AuthMode::Pase, &acl_mgr) + } + + SessionMode::PlainText => { + Accessor::new(0, AccessorSubjects::new(1), AuthMode::Invalid, &acl_mgr) + } + } + } + + pub const fn new( fab_idx: u8, subjects: AccessorSubjects, auth_mode: AuthMode, - acl_mgr: Arc, + acl_mgr: &'a RefCell, ) -> Self { Self { fab_idx, @@ -188,9 +207,9 @@ impl Accessor { } #[derive(Debug)] -pub struct AccessDesc<'a> { +pub struct AccessDesc { /// The object to be acted upon - path: &'a GenericPath, + path: GenericPath, /// The target permissions target_perms: Option, // The operation being done @@ -200,8 +219,8 @@ pub struct AccessDesc<'a> { /// Access Request Object pub struct AccessReq<'a> { - accessor: &'a Accessor, - object: AccessDesc<'a>, + accessor: &'a Accessor<'a>, + object: AccessDesc, } impl<'a> AccessReq<'a> { @@ -209,7 +228,7 @@ impl<'a> AccessReq<'a> { /// /// An access request specifies the _accessor_ attempting to access _path_ /// with _operation_ - pub fn new(accessor: &'a Accessor, path: &'a GenericPath, operation: Access) -> Self { + pub fn new(accessor: &'a Accessor, path: GenericPath, operation: Access) -> Self { AccessReq { accessor, object: AccessDesc { @@ -220,6 +239,10 @@ impl<'a> AccessReq<'a> { } } + pub fn operation(&self) -> Access { + self.object.operation + } + /// Add target's permissions to the request /// /// The permissions that are associated with the target (identified by the @@ -234,7 +257,7 @@ impl<'a> AccessReq<'a> { /// _accessor_ the necessary privileges to access the target as per its /// permissions pub fn allow(&self) -> bool { - self.accessor.acl_mgr.allow(self) + self.accessor.acl_mgr.borrow().allow(self) } } @@ -369,116 +392,36 @@ impl AclEntry { const MAX_ACL_ENTRIES: usize = ENTRIES_PER_FABRIC * fabric::MAX_SUPPORTED_FABRICS; type AclEntries = [Option; MAX_ACL_ENTRIES]; -#[derive(ToTLV, FromTLV, Debug)] -struct AclMgrInner { - entries: AclEntries, -} - const ACL_KV_ENTRY: &str = "acl"; const ACL_KV_MAX_SIZE: usize = 300; -impl AclMgrInner { - pub fn store(&self, psm: &MutexGuard) -> Result<(), Error> { - let mut acl_tlvs = [0u8; ACL_KV_MAX_SIZE]; - let mut wb = WriteBuf::new(&mut acl_tlvs, ACL_KV_MAX_SIZE); - let mut tw = TLVWriter::new(&mut wb); - self.entries.to_tlv(&mut tw, TagType::Anonymous)?; - psm.set_kv_slice(ACL_KV_ENTRY, wb.as_slice()) - } - - pub fn load(psm: &MutexGuard) -> Result { - let mut acl_tlvs = Vec::new(); - psm.get_kv_slice(ACL_KV_ENTRY, &mut acl_tlvs)?; - let root = TLVList::new(&acl_tlvs) - .iter() - .next() - .ok_or(Error::Invalid)?; - - Ok(Self { - entries: AclEntries::from_tlv(&root)?, - }) - } - - /// Traverse fabric specific entries to find the index - /// - /// If the ACL Mgr has 3 entries with fabric indexes, 1, 2, 1, then the list - /// index 1 for Fabric 1 in the ACL Mgr will be the actual index 2 (starting from 0) - fn for_index_in_fabric( - &mut self, - index: u8, - fab_idx: u8, - ) -> Result<&mut Option, Error> { - // Can't use flatten as we need to borrow the Option<> not the 'AclEntry' - for (curr_index, entry) in self - .entries - .iter_mut() - .filter(|e| e.filter(|e1| e1.fab_idx == Some(fab_idx)).is_some()) - .enumerate() - { - if curr_index == index as usize { - return Ok(entry); - } - } - Err(Error::NotFound) - } -} pub struct AclMgr { - inner: RwLock, - // The Option<> is solely because test execution is faster - // Doing this here adds the least overhead during ACL verification - psm: Option>>, + entries: AclEntries, + changed: bool, } impl AclMgr { - pub fn new() -> Result { - AclMgr::new_with(true) - } - - pub fn new_with(psm_support: bool) -> Result { + pub const fn new() -> Self { const INIT: Option = None; - let mut psm = None; - let inner = if !psm_support { - AclMgrInner { - entries: [INIT; MAX_ACL_ENTRIES], - } - } else { - let psm_handle = Psm::get()?; - let inner = { - let psm_lock = psm_handle.lock().unwrap(); - AclMgrInner::load(&psm_lock) - }; - - psm = Some(psm_handle); - inner.unwrap_or({ - // Error loading from PSM - AclMgrInner { - entries: [INIT; MAX_ACL_ENTRIES], - } - }) - }; - Ok(Self { - inner: RwLock::new(inner), - psm, - }) + Self { + entries: [INIT; MAX_ACL_ENTRIES], + changed: false, + } } - pub fn erase_all(&self) { - let mut inner = self.inner.write().unwrap(); + pub fn erase_all(&mut self) -> Result<(), Error> { for i in 0..MAX_ACL_ENTRIES { - inner.entries[i] = None; - } - if let Some(psm) = self.psm.as_ref() { - let psm = psm.lock().unwrap(); - let _ = inner.store(&psm).map_err(|e| { - error!("Error in storing ACLs {}", e); - }); + self.entries[i] = None; } + + self.changed = true; + + Ok(()) } - pub fn add(&self, entry: AclEntry) -> Result<(), Error> { - let mut inner = self.inner.write().unwrap(); - let cnt = inner + pub fn add(&mut self, entry: AclEntry) -> Result<(), Error> { + let cnt = self .entries .iter() .flatten() @@ -487,76 +430,60 @@ impl AclMgr { if cnt >= ENTRIES_PER_FABRIC { return Err(Error::NoSpace); } - let index = inner + let index = self .entries .iter() .position(|a| a.is_none()) .ok_or(Error::NoSpace)?; - inner.entries[index] = Some(entry); + self.entries[index] = Some(entry); - if let Some(psm) = self.psm.as_ref() { - let psm = psm.lock().unwrap(); - inner.store(&psm) - } else { - Ok(()) - } + self.changed = true; + + Ok(()) } // Since the entries are fabric-scoped, the index is only for entries with the matching fabric index - pub fn edit(&self, index: u8, fab_idx: u8, new: AclEntry) -> Result<(), Error> { - let mut inner = self.inner.write().unwrap(); - let old = inner.for_index_in_fabric(index, fab_idx)?; + pub fn edit(&mut self, index: u8, fab_idx: u8, new: AclEntry) -> Result<(), Error> { + let old = self.for_index_in_fabric(index, fab_idx)?; *old = Some(new); - if let Some(psm) = self.psm.as_ref() { - let psm = psm.lock().unwrap(); - inner.store(&psm) - } else { - Ok(()) - } + self.changed = true; + + Ok(()) } - pub fn delete(&self, index: u8, fab_idx: u8) -> Result<(), Error> { - let mut inner = self.inner.write().unwrap(); - let old = inner.for_index_in_fabric(index, fab_idx)?; + pub fn delete(&mut self, index: u8, fab_idx: u8) -> Result<(), Error> { + let old = self.for_index_in_fabric(index, fab_idx)?; *old = None; - if let Some(psm) = self.psm.as_ref() { - let psm = psm.lock().unwrap(); - inner.store(&psm) - } else { - Ok(()) - } - } + self.changed = true; - pub fn delete_for_fabric(&self, fab_idx: u8) -> Result<(), Error> { - let mut inner = self.inner.write().unwrap(); + Ok(()) + } + pub fn delete_for_fabric(&mut self, fab_idx: u8) -> Result<(), Error> { for i in 0..MAX_ACL_ENTRIES { - if inner.entries[i] + if self.entries[i] .filter(|e| e.fab_idx == Some(fab_idx)) .is_some() { - inner.entries[i] = None; + self.entries[i] = None; } } - if let Some(psm) = self.psm.as_ref() { - let psm = psm.lock().unwrap(); - inner.store(&psm) - } else { - Ok(()) - } + self.changed = true; + + Ok(()) } pub fn for_each_acl(&self, mut f: T) -> Result<(), Error> where - T: FnMut(&AclEntry), + T: FnMut(&AclEntry) -> Result<(), Error>, { - let inner = self.inner.read().unwrap(); - for entry in inner.entries.iter().flatten() { - f(entry) + for entry in self.entries.iter().flatten() { + f(entry)?; } + Ok(()) } @@ -565,8 +492,7 @@ impl AclMgr { if req.accessor.auth_mode == AuthMode::Pase { return true; } - let inner = self.inner.read().unwrap(); - for e in inner.entries.iter().flatten() { + for e in self.entries.iter().flatten() { if e.allow(req) { return true; } @@ -578,13 +504,102 @@ impl AclMgr { error!("{}", self); false } + + pub fn store(&mut self, mut psm: T) -> Result<(), Error> + where + T: Psm, + { + if self.changed { + let mut buf = [0u8; ACL_KV_MAX_SIZE]; + let mut wb = WriteBuf::new(&mut buf); + let mut tw = TLVWriter::new(&mut wb); + self.entries.to_tlv(&mut tw, TagType::Anonymous)?; + psm.set_kv_slice(ACL_KV_ENTRY, wb.into_slice())?; + + self.changed = false; + } + + Ok(()) + } + + pub fn load(&mut self, psm: T) -> Result<(), Error> + where + T: Psm, + { + let mut buf = [0u8; ACL_KV_MAX_SIZE]; + let acl_tlvs = psm.get_kv_slice(ACL_KV_ENTRY, &mut buf)?; + let root = TLVList::new(acl_tlvs).iter().next().ok_or(Error::Invalid)?; + + self.entries = AclEntries::from_tlv(&root)?; + self.changed = false; + + Ok(()) + } + + #[cfg(feature = "nightly")] + pub async fn store_async(&mut self, mut psm: T) -> Result<(), Error> + where + T: crate::persist::asynch::AsyncPsm, + { + if self.changed { + let mut buf = [0u8; ACL_KV_MAX_SIZE]; + let mut wb = WriteBuf::new(&mut buf); + let mut tw = TLVWriter::new(&mut wb); + self.entries.to_tlv(&mut tw, TagType::Anonymous)?; + psm.set_kv_slice(ACL_KV_ENTRY, wb.into_slice()).await?; + + self.changed = false; + } + + Ok(()) + } + + #[cfg(feature = "nightly")] + pub async fn load_async(&mut self, psm: T) -> Result<(), Error> + where + T: crate::persist::asynch::AsyncPsm, + { + let mut buf = [0u8; ACL_KV_MAX_SIZE]; + let acl_tlvs = psm.get_kv_slice(ACL_KV_ENTRY, &mut buf).await?; + let root = TLVList::new(&acl_tlvs) + .iter() + .next() + .ok_or(Error::Invalid)?; + + self.entries = AclEntries::from_tlv(&root)?; + self.changed = false; + + Ok(()) + } + + /// Traverse fabric specific entries to find the index + /// + /// If the ACL Mgr has 3 entries with fabric indexes, 1, 2, 1, then the list + /// index 1 for Fabric 1 in the ACL Mgr will be the actual index 2 (starting from 0) + fn for_index_in_fabric( + &mut self, + index: u8, + fab_idx: u8, + ) -> Result<&mut Option, Error> { + // Can't use flatten as we need to borrow the Option<> not the 'AclEntry' + for (curr_index, entry) in self + .entries + .iter_mut() + .filter(|e| e.filter(|e1| e1.fab_idx == Some(fab_idx)).is_some()) + .enumerate() + { + if curr_index == index as usize { + return Ok(entry); + } + } + Err(Error::NotFound) + } } -impl std::fmt::Display for AclMgr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let inner = self.inner.read().unwrap(); +impl core::fmt::Display for AclMgr { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "ACLS: [")?; - for i in inner.entries.iter().flatten() { + for i in self.entries.iter().flatten() { write!(f, " {{ {:?} }}, ", i)?; } write!(f, "]") @@ -594,22 +609,23 @@ impl std::fmt::Display for AclMgr { #[cfg(test)] #[allow(clippy::bool_assert_comparison)] mod tests { + use core::cell::RefCell; + use crate::{ acl::{gen_noc_cat, AccessorSubjects}, data_model::objects::{Access, Privilege}, interaction_model::messages::GenericPath, }; - use std::sync::Arc; use super::{AccessReq, Accessor, AclEntry, AclMgr, AuthMode, Target}; #[test] fn test_basic_empty_subject_target() { - let am = Arc::new(AclMgr::new_with(false).unwrap()); - am.erase_all(); - let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); + let am = RefCell::new(AclMgr::new()); + am.borrow_mut().erase_all().unwrap(); + let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am); let path = GenericPath::new(Some(1), Some(1234), None); - let mut req = AccessReq::new(&accessor, &path, Access::READ); + let mut req = AccessReq::new(&accessor, path, Access::READ); req.set_target_perms(Access::RWVA); // Default deny @@ -617,46 +633,46 @@ mod tests { // Deny for session mode mismatch let new = AclEntry::new(1, Privilege::VIEW, AuthMode::Pase); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), false); // Deny for fab idx mismatch let new = AclEntry::new(1, Privilege::VIEW, AuthMode::Case); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), false); // Allow let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), true); } #[test] fn test_subject() { - let am = Arc::new(AclMgr::new_with(false).unwrap()); - am.erase_all(); - let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); + let am = RefCell::new(AclMgr::new()); + am.borrow_mut().erase_all().unwrap(); + let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am); let path = GenericPath::new(Some(1), Some(1234), None); - let mut req = AccessReq::new(&accessor, &path, Access::READ); + let mut req = AccessReq::new(&accessor, path, Access::READ); req.set_target_perms(Access::RWVA); // Deny for subject mismatch let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.add_subject(112232).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), false); // Allow for subject match - target is wildcard let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.add_subject(112233).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), true); } #[test] fn test_cat() { - let am = Arc::new(AclMgr::new_with(false).unwrap()); - am.erase_all(); + let am = RefCell::new(AclMgr::new()); + am.borrow_mut().erase_all().unwrap(); let allow_cat = 0xABCD; let disallow_cat = 0xCAFE; @@ -666,35 +682,35 @@ mod tests { let mut subjects = AccessorSubjects::new(112233); subjects.add_catid(gen_noc_cat(allow_cat, v2)).unwrap(); - let accessor = Accessor::new(2, subjects, AuthMode::Case, am.clone()); + let accessor = Accessor::new(2, subjects, AuthMode::Case, &am); let path = GenericPath::new(Some(1), Some(1234), None); - let mut req = AccessReq::new(&accessor, &path, Access::READ); + let mut req = AccessReq::new(&accessor, path, Access::READ); req.set_target_perms(Access::RWVA); // Deny for CAT id mismatch let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.add_subject_catid(gen_noc_cat(disallow_cat, v2)) .unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), false); // Deny of CAT version mismatch let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.add_subject_catid(gen_noc_cat(allow_cat, v3)).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), false); // Allow for CAT match let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.add_subject_catid(gen_noc_cat(allow_cat, v2)).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), true); } #[test] fn test_cat_version() { - let am = Arc::new(AclMgr::new_with(false).unwrap()); - am.erase_all(); + let am = RefCell::new(AclMgr::new()); + am.borrow_mut().erase_all().unwrap(); let allow_cat = 0xABCD; let disallow_cat = 0xCAFE; @@ -704,32 +720,32 @@ mod tests { let mut subjects = AccessorSubjects::new(112233); subjects.add_catid(gen_noc_cat(allow_cat, v3)).unwrap(); - let accessor = Accessor::new(2, subjects, AuthMode::Case, am.clone()); + let accessor = Accessor::new(2, subjects, AuthMode::Case, &am); let path = GenericPath::new(Some(1), Some(1234), None); - let mut req = AccessReq::new(&accessor, &path, Access::READ); + let mut req = AccessReq::new(&accessor, path, Access::READ); req.set_target_perms(Access::RWVA); // Deny for CAT id mismatch let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.add_subject_catid(gen_noc_cat(disallow_cat, v2)) .unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), false); // Allow for CAT match and version more than ACL version let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.add_subject_catid(gen_noc_cat(allow_cat, v2)).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), true); } #[test] fn test_target() { - let am = Arc::new(AclMgr::new_with(false).unwrap()); - am.erase_all(); - let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); + let am = RefCell::new(AclMgr::new()); + am.borrow_mut().erase_all().unwrap(); + let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am); let path = GenericPath::new(Some(1), Some(1234), None); - let mut req = AccessReq::new(&accessor, &path, Access::READ); + let mut req = AccessReq::new(&accessor, path, Access::READ); req.set_target_perms(Access::RWVA); // Deny for target mismatch @@ -740,7 +756,7 @@ mod tests { device_type: None, }) .unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), false); // Allow for cluster match - subject wildcard @@ -751,11 +767,11 @@ mod tests { device_type: None, }) .unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), true); // Clean Slate - am.erase_all(); + am.borrow_mut().erase_all().unwrap(); // Allow for endpoint match - subject wildcard let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); @@ -765,11 +781,11 @@ mod tests { device_type: None, }) .unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), true); // Clean Slate - am.erase_all(); + am.borrow_mut().erase_all().unwrap(); // Allow for exact match let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); @@ -780,16 +796,15 @@ mod tests { }) .unwrap(); new.add_subject(112233).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), true); } #[test] fn test_privilege() { - let am = Arc::new(AclMgr::new_with(false).unwrap()); - am.erase_all(); - - let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); + let am = RefCell::new(AclMgr::new()); + am.borrow_mut().erase_all().unwrap(); + let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am); let path = GenericPath::new(Some(1), Some(1234), None); // Create an Exact Match ACL with View privilege @@ -801,10 +816,10 @@ mod tests { }) .unwrap(); new.add_subject(112233).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); // Write on an RWVA without admin access - deny - let mut req = AccessReq::new(&accessor, &path, Access::WRITE); + let mut req = AccessReq::new(&accessor, path, Access::WRITE); req.set_target_perms(Access::RWVA); assert_eq!(req.allow(), false); @@ -817,40 +832,40 @@ mod tests { }) .unwrap(); new.add_subject(112233).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); // Write on an RWVA with admin access - allow - let mut req = AccessReq::new(&accessor, &path, Access::WRITE); + let mut req = AccessReq::new(&accessor, path, Access::WRITE); req.set_target_perms(Access::RWVA); assert_eq!(req.allow(), true); } #[test] fn test_delete_for_fabric() { - let am = Arc::new(AclMgr::new_with(false).unwrap()); - am.erase_all(); + let am = RefCell::new(AclMgr::new()); + am.borrow_mut().erase_all().unwrap(); let path = GenericPath::new(Some(1), Some(1234), None); - let accessor2 = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); - let mut req2 = AccessReq::new(&accessor2, &path, Access::READ); + let accessor2 = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am); + let mut req2 = AccessReq::new(&accessor2, path, Access::READ); req2.set_target_perms(Access::RWVA); - let accessor3 = Accessor::new(3, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); - let mut req3 = AccessReq::new(&accessor3, &path, Access::READ); + let accessor3 = Accessor::new(3, AccessorSubjects::new(112233), AuthMode::Case, &am); + let mut req3 = AccessReq::new(&accessor3, path, Access::READ); req3.set_target_perms(Access::RWVA); // Allow for subject match - target is wildcard - Fabric idx 2 let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.add_subject(112233).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); // Allow for subject match - target is wildcard - Fabric idx 3 let mut new = AclEntry::new(3, Privilege::VIEW, AuthMode::Case); new.add_subject(112233).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); // Req for Fabric idx 2 gets denied, and that for Fabric idx 3 is allowed assert_eq!(req2.allow(), true); assert_eq!(req3.allow(), true); - am.delete_for_fabric(2).unwrap(); + am.borrow_mut().delete_for_fabric(2).unwrap(); assert_eq!(req2.allow(), false); assert_eq!(req3.allow(), true); } diff --git a/matter/src/cert/asn1_writer.rs b/matter/src/cert/asn1_writer.rs index b3bb13ce..ae2ced83 100644 --- a/matter/src/cert/asn1_writer.rs +++ b/matter/src/cert/asn1_writer.rs @@ -18,6 +18,7 @@ use super::{CertConsumer, MAX_DEPTH}; use crate::error::Error; use chrono::{Datelike, TimeZone, Utc}; +use core::fmt::Write; use log::warn; #[derive(Debug)] @@ -279,10 +280,12 @@ impl<'a> CertConsumer for ASN1Writer<'a> { if dt.year() >= 2050 { // If year is >= 2050, ASN.1 requires it to be Generalised Time - let time_str = format!("{}Z", dt.format("%Y%m%d%H%M%S")); + let mut time_str = heapless::String::<32>::new(); + write!(&mut time_str, "{}Z", dt.format("%Y%m%d%H%M%S")).unwrap(); self.write_str(0x18, time_str.as_bytes()) } else { - let time_str = format!("{}Z", dt.format("%y%m%d%H%M%S")); + let mut time_str = heapless::String::<32>::new(); + write!(&mut time_str, "{}Z", dt.format("%y%m%d%H%M%S")).unwrap(); self.write_str(0x17, time_str.as_bytes()) } } diff --git a/matter/src/cert/mod.rs b/matter/src/cert/mod.rs index 360ce31f..664125b5 100644 --- a/matter/src/cert/mod.rs +++ b/matter/src/cert/mod.rs @@ -15,14 +15,17 @@ * limitations under the License. */ -use std::fmt; +use core::fmt; + +extern crate alloc; use crate::{ - crypto::{CryptoKeyPair, KeyPair}, + crypto::KeyPair, error::Error, tlv::{self, FromTLV, TLVArrayOwned, TLVElement, TLVWriter, TagType, ToTLV}, utils::writebuf::WriteBuf, }; +use alloc::{format, string::String, vec::Vec}; use log::error; use num_derive::FromPrimitive; @@ -591,10 +594,10 @@ impl Cert { } pub fn as_tlv(&self, buf: &mut [u8]) -> Result { - let mut wb = WriteBuf::new(buf, buf.len()); + let mut wb = WriteBuf::new(buf); let mut tw = TLVWriter::new(&mut wb); self.to_tlv(&mut tw, TagType::Anonymous)?; - Ok(wb.as_slice().len()) + Ok(wb.into_slice().len()) } pub fn as_asn1(&self, buf: &mut [u8]) -> Result { @@ -731,6 +734,8 @@ mod printer; #[cfg(test)] mod tests { + use log::info; + use crate::cert::Cert; use crate::error::Error; use crate::tlv::{self, FromTLV, TLVWriter, TagType, ToTLV}; @@ -811,15 +816,14 @@ mod tests { ]; for input in test_input.iter() { - println!("Testing next input..."); + info!("Testing next input..."); let root = tlv::get_root_node(input).unwrap(); let cert = Cert::from_tlv(&root).unwrap(); let mut buf = [0u8; 1024]; - let buf_len = buf.len(); - let mut wb = WriteBuf::new(&mut buf, buf_len); + let mut wb = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut wb); cert.to_tlv(&mut tw, TagType::Anonymous).unwrap(); - assert_eq!(*input, wb.as_slice()); + assert_eq!(*input, wb.into_slice()); } } diff --git a/matter/src/cert/printer.rs b/matter/src/cert/printer.rs index e92dbd46..b9336073 100644 --- a/matter/src/cert/printer.rs +++ b/matter/src/cert/printer.rs @@ -18,8 +18,8 @@ use super::{CertConsumer, MAX_DEPTH}; use crate::error::Error; use chrono::{TimeZone, Utc}; +use core::fmt; use log::warn; -use std::fmt; pub struct CertPrinter<'a, 'b> { level: usize, diff --git a/matter/src/codec/base38.rs b/matter/src/codec/base38.rs index d4c69c14..7b7e7587 100644 --- a/matter/src/codec/base38.rs +++ b/matter/src/codec/base38.rs @@ -17,6 +17,10 @@ //! Base38 encoding and decoding functions. +extern crate alloc; + +use alloc::{string::String, vec::Vec}; + use crate::error::Error; const BASE38_CHARS: [char; 38] = [ @@ -97,7 +101,7 @@ pub fn encode(bytes: &[u8], length: Option) -> String { while offset < length { let remaining = length - offset; match remaining.cmp(&2) { - std::cmp::Ordering::Greater => { + core::cmp::Ordering::Greater => { result.push_str(&encode_base38( ((bytes[offset + 2] as u32) << 16) | ((bytes[offset + 1] as u32) << 8) @@ -106,14 +110,14 @@ pub fn encode(bytes: &[u8], length: Option) -> String { )); offset += 3; } - std::cmp::Ordering::Equal => { + core::cmp::Ordering::Equal => { result.push_str(&encode_base38( ((bytes[offset + 1] as u32) << 8) | (bytes[offset] as u32), 4, )); break; } - std::cmp::Ordering::Less => { + core::cmp::Ordering::Less => { result.push_str(&encode_base38(bytes[offset] as u32, 2)); break; } diff --git a/matter/src/core.rs b/matter/src/core.rs index 9f1b13bc..7b853b96 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -15,21 +15,19 @@ * limitations under the License. */ +use core::{borrow::Borrow, cell::RefCell}; + use crate::{ acl::AclMgr, - data_model::{ - cluster_basic_information::BasicInfoConfig, core::DataModel, - sdm::dev_att::DevAttDataFetcher, - }, + data_model::{cluster_basic_information::BasicInfoConfig, sdm::failsafe::FailSafe}, error::*, fabric::FabricMgr, - interaction_model::InteractionModel, - mdns::Mdns, + mdns::{Mdns, MdnsMgr}, pairing::{print_pairing_code_and_qr, DiscoveryCapabilities}, - secure_channel::{core::SecureChannel, pake::PaseMgr, spake2p::VerifierData}, - transport, + secure_channel::{pake::PaseMgr, spake2p::VerifierData}, + transport::udp::MATTER_PORT, + utils::{epoch::Epoch, rand::Rand}, }; -use std::sync::Arc; /// Device Commissioning Data pub struct CommissioningData { @@ -40,13 +38,18 @@ pub struct CommissioningData { } /// The primary Matter Object -pub struct Matter { - transport_mgr: transport::mgr::Mgr, - data_model: DataModel, - fabric_mgr: Arc, +pub struct Matter<'a> { + pub fabric_mgr: RefCell, + pub acl_mgr: RefCell, + pub pase_mgr: RefCell, + pub failsafe: RefCell, + pub mdns_mgr: RefCell>, + pub epoch: Epoch, + pub rand: Rand, + pub dev_det: &'a BasicInfoConfig<'a>, } -impl Matter { +impl<'a> Matter<'a> { /// Creates a new Matter object /// /// # Parameters @@ -54,57 +57,87 @@ impl Matter { /// requires a set of device attestation certificates and keys. It is the responsibility of /// this object to return the device attestation details when queried upon. pub fn new( - dev_det: BasicInfoConfig, - dev_att: Box, - dev_comm: CommissioningData, - ) -> Result, Error> { - let mdns = Mdns::get()?; - mdns.set_values(dev_det.vid, dev_det.pid, &dev_det.device_name); - - let fabric_mgr = Arc::new(FabricMgr::new()?); - let open_comm_window = fabric_mgr.is_empty(); - if open_comm_window { - print_pairing_code_and_qr(&dev_det, &dev_comm, DiscoveryCapabilities::default()); + dev_det: &'a BasicInfoConfig, + mdns: &'a mut dyn Mdns, + epoch: Epoch, + rand: Rand, + ) -> Self { + Self { + fabric_mgr: RefCell::new(FabricMgr::new()), + acl_mgr: RefCell::new(AclMgr::new()), + pase_mgr: RefCell::new(PaseMgr::new(epoch, rand)), + failsafe: RefCell::new(FailSafe::new()), + mdns_mgr: RefCell::new(MdnsMgr::new( + dev_det.vid, + dev_det.pid, + dev_det.device_name, + MATTER_PORT, + mdns, + )), + epoch, + rand, + dev_det, } + } - let acl_mgr = Arc::new(AclMgr::new()?); - let mut pase = PaseMgr::new(); - let data_model = - DataModel::new(dev_det, dev_att, fabric_mgr.clone(), acl_mgr, pase.clone())?; - let mut matter = Box::new(Matter { - transport_mgr: transport::mgr::Mgr::new()?, - data_model, - fabric_mgr, - }); - let interaction_model = - Box::new(InteractionModel::new(Box::new(matter.data_model.clone()))); - matter.transport_mgr.register_protocol(interaction_model)?; + pub fn dev_det(&self) -> &BasicInfoConfig { + self.dev_det + } + pub fn start(&mut self, dev_comm: CommissioningData) -> Result<(), Error> { + let open_comm_window = self.fabric_mgr.borrow().is_empty(); if open_comm_window { - pase.enable_pase_session(dev_comm.verifier, dev_comm.discriminator)?; + print_pairing_code_and_qr(self.dev_det, &dev_comm, DiscoveryCapabilities::default()); + + self.pase_mgr.borrow_mut().enable_pase_session( + dev_comm.verifier, + dev_comm.discriminator, + &mut self.mdns_mgr.borrow_mut(), + )?; } - let secure_channel = Box::new(SecureChannel::new(pase, matter.fabric_mgr.clone())); - matter.transport_mgr.register_protocol(secure_channel)?; - Ok(matter) + Ok(()) } +} - /// Returns an Arc to [DataModel] - /// - /// The Data Model is where you express what is the type of your device. Typically - /// once you gets this reference, you acquire the write lock and add your device - /// types, clusters, attributes, commands to the data model. - pub fn get_data_model(&self) -> DataModel { - self.data_model.clone() +impl<'a> Borrow> for Matter<'a> { + fn borrow(&self) -> &RefCell { + &self.fabric_mgr } +} - /// Starts the Matter daemon - /// - /// This call does NOT return - /// - /// This call starts the Matter daemon that starts communication with other Matter - /// devices on the network. - pub fn start_daemon(&mut self) -> Result<(), Error> { - self.transport_mgr.start() +impl<'a> Borrow> for Matter<'a> { + fn borrow(&self) -> &RefCell { + &self.acl_mgr + } +} + +impl<'a> Borrow> for Matter<'a> { + fn borrow(&self) -> &RefCell { + &self.pase_mgr + } +} + +impl<'a> Borrow> for Matter<'a> { + fn borrow(&self) -> &RefCell { + &self.failsafe + } +} + +impl<'a> Borrow>> for Matter<'a> { + fn borrow(&self) -> &RefCell> { + &self.mdns_mgr + } +} + +impl<'a> Borrow for Matter<'a> { + fn borrow(&self) -> &Epoch { + &self.epoch + } +} + +impl<'a> Borrow for Matter<'a> { + fn borrow(&self) -> &Rand { + &self.rand } } diff --git a/matter/src/crypto/crypto_dummy.rs b/matter/src/crypto/crypto_dummy.rs index 80c12887..f193b205 100644 --- a/matter/src/crypto/crypto_dummy.rs +++ b/matter/src/crypto/crypto_dummy.rs @@ -19,39 +19,118 @@ use log::error; use crate::error::Error; -use super::CryptoKeyPair; +pub fn hkdf_sha256(_salt: &[u8], _ikm: &[u8], _info: &[u8], _key: &mut [u8]) -> Result<(), Error> { + error!("This API should never get called"); + Ok(()) +} -pub struct KeyPairDummy {} +#[derive(Clone)] +pub struct Sha256 {} -impl KeyPairDummy { +impl Sha256 { pub fn new() -> Result { Ok(Self {}) } + + pub fn update(&mut self, _data: &[u8]) -> Result<(), Error> { + Ok(()) + } + + pub fn finish(self, _digest: &mut [u8]) -> Result<(), Error> { + Ok(()) + } +} + +pub struct HmacSha256 {} + +impl HmacSha256 { + pub fn new(_key: &[u8]) -> Result { + error!("This API should never get called"); + Ok(Self {}) + } + + pub fn update(&mut self, _data: &[u8]) -> Result<(), Error> { + error!("This API should never get called"); + Ok(()) + } + + pub fn finish(self, _out: &mut [u8]) -> Result<(), Error> { + error!("This API should never get called"); + Ok(()) + } } -impl CryptoKeyPair for KeyPairDummy { - fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { +pub struct KeyPair; + +impl KeyPair { + pub fn new() -> Result { + Ok(Self) + } + + pub fn new_from_components(_pub_key: &[u8], _priv_key: &[u8]) -> Result { + error!("This API should never get called"); + + Ok(Self {}) + } + + pub fn new_from_public(_pub_key: &[u8]) -> Result { + error!("This API should never get called"); + + Ok(Self {}) + } + + pub fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { error!("This API should never get called"); Err(Error::Invalid) } - fn get_public_key(&self, _pub_key: &mut [u8]) -> Result { + + pub fn get_public_key(&self, _pub_key: &mut [u8]) -> Result { error!("This API should never get called"); Err(Error::Invalid) } - fn get_private_key(&self, _pub_key: &mut [u8]) -> Result { + + pub fn get_private_key(&self, _pub_key: &mut [u8]) -> Result { error!("This API should never get called"); Err(Error::Invalid) } - fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result { + + pub fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result { error!("This API should never get called"); Err(Error::Invalid) } - fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result { + + pub fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result { error!("This API should never get called"); Err(Error::Invalid) } - fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> { + + pub fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> { error!("This API should never get called"); Err(Error::Invalid) } } + +pub fn pbkdf2_hmac(_pass: &[u8], _iter: usize, _salt: &[u8], _key: &mut [u8]) -> Result<(), Error> { + error!("This API should never get called"); + + Ok(()) +} + +pub fn encrypt_in_place( + _key: &[u8], + _nonce: &[u8], + _ad: &[u8], + _data: &mut [u8], + _data_len: usize, +) -> Result { + Ok(0) +} + +pub fn decrypt_in_place( + _key: &[u8], + _nonce: &[u8], + _ad: &[u8], + _data: &mut [u8], +) -> Result { + Ok(0) +} diff --git a/matter/src/crypto/crypto_esp_mbedtls.rs b/matter/src/crypto/crypto_esp_mbedtls.rs index 9a8495d6..fe723370 100644 --- a/matter/src/crypto/crypto_esp_mbedtls.rs +++ b/matter/src/crypto/crypto_esp_mbedtls.rs @@ -19,8 +19,6 @@ use log::error; use crate::error::Error; -use super::CryptoKeyPair; - pub fn hkdf_sha256(_salt: &[u8], _ikm: &[u8], _info: &[u8], _key: &mut [u8]) -> Result<(), Error> { error!("This API should never get called"); Ok(()) @@ -82,26 +80,28 @@ impl KeyPair { Ok(Self {}) } -} -impl CryptoKeyPair for KeyPair { - fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { + pub fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { error!("This API should never get called"); Err(Error::Invalid) } - fn get_public_key(&self, _pub_key: &mut [u8]) -> Result { + + pub fn get_public_key(&self, _pub_key: &mut [u8]) -> Result { error!("This API should never get called"); Err(Error::Invalid) } - fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result { + + pub fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result { error!("This API should never get called"); Err(Error::Invalid) } - fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result { + + pub fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result { error!("This API should never get called"); Err(Error::Invalid) } - fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> { + + pub fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> { error!("This API should never get called"); Err(Error::Invalid) } diff --git a/matter/src/crypto/crypto_mbedtls.rs b/matter/src/crypto/crypto_mbedtls.rs index e8185835..2890fd19 100644 --- a/matter/src/crypto/crypto_mbedtls.rs +++ b/matter/src/crypto/crypto_mbedtls.rs @@ -15,9 +15,11 @@ * limitations under the License. */ -use std::sync::Arc; +extern crate alloc; -use log::error; +use alloc::sync::Arc; + +use log::{error, info}; use mbedtls::{ bignum::Mpi, cipher::{Authenticated, Cipher}, @@ -28,7 +30,6 @@ use mbedtls::{ x509, }; -use super::CryptoKeyPair; use crate::{ // TODO: We should move ASN1Writer out of Cert, // so Crypto doesn't have to depend on Cert @@ -85,10 +86,8 @@ impl KeyPair { key: Pk::public_from_ec_components(group, pub_key)?, }) } -} -impl CryptoKeyPair for KeyPair { - fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { + pub fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { let tmp_priv = self.key.ec_private()?; let mut tmp_key = Pk::private_from_ec_components(EcGroup::new(EcGroupId::SecP256R1)?, tmp_priv)?; @@ -112,7 +111,7 @@ impl CryptoKeyPair for KeyPair { } } - fn get_public_key(&self, pub_key: &mut [u8]) -> Result { + pub fn get_public_key(&self, pub_key: &mut [u8]) -> Result { let public_key = self.key.ec_public()?; let group = EcGroup::new(EcGroupId::SecP256R1)?; let vec = public_key.to_binary(&group, false)?; @@ -122,7 +121,7 @@ impl CryptoKeyPair for KeyPair { Ok(len) } - fn get_private_key(&self, priv_key: &mut [u8]) -> Result { + pub fn get_private_key(&self, priv_key: &mut [u8]) -> Result { let priv_key_mpi = self.key.ec_private()?; let vec = priv_key_mpi.to_binary()?; @@ -131,7 +130,7 @@ impl CryptoKeyPair for KeyPair { Ok(len) } - fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result { + pub fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result { // mbedtls requires a 'mut' key. Instead of making a change in our Trait, // we just clone the key this way @@ -149,7 +148,7 @@ impl CryptoKeyPair for KeyPair { Ok(len) } - fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result { + pub fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result { // mbedtls requires a 'mut' key. Instead of making a change in our Trait, // we just clone the key this way let tmp_key = self.key.ec_private()?; @@ -175,7 +174,7 @@ impl CryptoKeyPair for KeyPair { Ok(len) } - fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> { + pub fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> { // mbedtls requires a 'mut' key. Instead of making a change in our Trait, // we just clone the key this way let tmp_key = self.key.ec_public()?; @@ -192,7 +191,7 @@ impl CryptoKeyPair for KeyPair { let mbedtls_sign = &mbedtls_sign[..len]; if let Err(e) = tmp_key.verify(hash::Type::Sha256, &msg_hash, mbedtls_sign) { - println!("The error is {}", e); + info!("The error is {}", e); Err(Error::InvalidSignature) } else { Ok(()) diff --git a/matter/src/crypto/crypto_openssl.rs b/matter/src/crypto/crypto_openssl.rs index e8b67c12..e4486192 100644 --- a/matter/src/crypto/crypto_openssl.rs +++ b/matter/src/crypto/crypto_openssl.rs @@ -17,7 +17,6 @@ use crate::error::Error; -use super::CryptoKeyPair; use foreign_types::ForeignTypeRef; use log::error; use openssl::asn1::Asn1Type; @@ -112,10 +111,8 @@ impl KeyPair { KeyType::Private(k) => Ok(&k), } } -} -impl CryptoKeyPair for KeyPair { - fn get_public_key(&self, pub_key: &mut [u8]) -> Result { + pub fn get_public_key(&self, pub_key: &mut [u8]) -> Result { let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?; let mut bn_ctx = BigNumContext::new()?; let s = self.public_key_point().to_bytes( @@ -128,14 +125,14 @@ impl CryptoKeyPair for KeyPair { Ok(len) } - fn get_private_key(&self, priv_key: &mut [u8]) -> Result { + pub fn get_private_key(&self, priv_key: &mut [u8]) -> Result { let s = self.private_key()?.private_key().to_vec(); let len = s.len(); priv_key[..len].copy_from_slice(s.as_slice()); Ok(len) } - fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result { + pub fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result { let self_pkey = PKey::from_ec_key(self.private_key()?.clone())?; let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?; @@ -149,7 +146,7 @@ impl CryptoKeyPair for KeyPair { Ok(deriver.derive(secret)?) } - fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { + pub fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { let mut builder = X509ReqBuilder::new()?; builder.set_version(0)?; @@ -174,7 +171,7 @@ impl CryptoKeyPair for KeyPair { } } - fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result { + pub fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result { // First get the SHA256 of the message let mut h = Hasher::new(MessageDigest::sha256())?; h.update(msg)?; @@ -193,7 +190,7 @@ impl CryptoKeyPair for KeyPair { Ok(64) } - fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> { + pub fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> { // First get the SHA256 of the message let mut h = Hasher::new(MessageDigest::sha256())?; h.update(msg)?; diff --git a/matter/src/crypto/mod.rs b/matter/src/crypto/mod.rs index 2473fb0d..5c73ff2d 100644 --- a/matter/src/crypto/mod.rs +++ b/matter/src/crypto/mod.rs @@ -15,8 +15,6 @@ * limitations under the License. */ -use crate::error::Error; - pub const SYMM_KEY_LEN_BITS: usize = 128; pub const SYMM_KEY_LEN_BYTES: usize = SYMM_KEY_LEN_BITS / 8; @@ -35,16 +33,6 @@ pub const ECDH_SHARED_SECRET_LEN_BYTES: usize = 32; pub const EC_SIGNATURE_LEN_BYTES: usize = 64; -// APIs particular to a KeyPair so a KeyPair object can be defined -pub trait CryptoKeyPair { - fn get_csr<'a>(&self, csr: &'a mut [u8]) -> Result<&'a [u8], Error>; - fn get_public_key(&self, pub_key: &mut [u8]) -> Result; - fn get_private_key(&self, priv_key: &mut [u8]) -> Result; - fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result; - fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result; - fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error>; -} - #[cfg(feature = "crypto_esp_mbedtls")] mod crypto_esp_mbedtls; #[cfg(feature = "crypto_esp_mbedtls")] @@ -65,13 +53,26 @@ mod crypto_rustcrypto; #[cfg(feature = "crypto_rustcrypto")] pub use self::crypto_rustcrypto::*; +#[cfg(not(any( + feature = "crypto_openssl", + feature = "crypto_mbedtls", + feature = "crypto_esp_mbedtls", + feature = "crypto_rustcrypto" +)))] pub mod crypto_dummy; +#[cfg(not(any( + feature = "crypto_openssl", + feature = "crypto_mbedtls", + feature = "crypto_esp_mbedtls", + feature = "crypto_rustcrypto" +)))] +pub use self::crypto_dummy::*; #[cfg(test)] mod tests { use crate::error::Error; - use super::{CryptoKeyPair, KeyPair}; + use super::KeyPair; #[test] fn test_verify_msg_success() { diff --git a/matter/src/data_model/cluster_basic_information.rs b/matter/src/data_model/cluster_basic_information.rs index 7c9cada5..71c07229 100644 --- a/matter/src/data_model/cluster_basic_information.rs +++ b/matter/src/data_model/cluster_basic_information.rs @@ -15,100 +15,129 @@ * limitations under the License. */ +use core::convert::TryInto; + use super::objects::*; -use crate::error::*; -use num_derive::FromPrimitive; +use crate::{attribute_enum, error::Error, utils::rand::Rand}; +use strum::{EnumDiscriminants, FromRepr}; pub const ID: u32 = 0x0028; -#[derive(FromPrimitive)] +#[derive(Clone, Copy, Debug, FromRepr, EnumDiscriminants)] +#[repr(u16)] pub enum Attributes { - DMRevision = 0, - VendorId = 2, - ProductId = 4, - HwVer = 7, - SwVer = 9, - SwVerString = 0xa, - SerialNo = 0x0f, + DMRevision(AttrType) = 0, + VendorId(AttrType) = 2, + ProductId(AttrType) = 4, + HwVer(AttrType) = 7, + SwVer(AttrType) = 9, + SwVerString(AttrUtfType) = 0xa, + SerialNo(AttrUtfType) = 0x0f, } +attribute_enum!(Attributes); + #[derive(Default)] -pub struct BasicInfoConfig { +pub struct BasicInfoConfig<'a> { pub vid: u16, pub pid: u16, pub hw_ver: u16, pub sw_ver: u32, - pub sw_ver_str: String, - pub serial_no: String, + pub sw_ver_str: &'a str, + pub serial_no: &'a str, /// Device name; up to 32 characters - pub device_name: String, + pub device_name: &'a str, } -pub struct BasicInfoCluster { - base: Cluster, -} +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID as _, + feature_map: 0, + attributes: &[ + FEATURE_MAP, + ATTRIBUTE_LIST, + Attribute::new( + AttributesDiscriminants::DMRevision as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::VendorId as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::ProductId as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::HwVer as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::SwVer as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::SwVerString as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::SerialNo as u16, + Access::RV, + Quality::FIXED, + ), + ], + commands: &[], +}; -impl BasicInfoCluster { - pub fn new(cfg: BasicInfoConfig) -> Result, Error> { - let mut cluster = Box::new(BasicInfoCluster { - base: Cluster::new(ID)?, - }); +pub struct BasicInfoCluster<'a> { + data_ver: Dataver, + cfg: &'a BasicInfoConfig<'a>, +} - let attrs = [ - Attribute::new( - Attributes::DMRevision as u16, - AttrValue::Uint8(1), - Access::RV, - Quality::FIXED, - ), - Attribute::new( - Attributes::VendorId as u16, - AttrValue::Uint16(cfg.vid), - Access::RV, - Quality::FIXED, - ), - Attribute::new( - Attributes::ProductId as u16, - AttrValue::Uint16(cfg.pid), - Access::RV, - Quality::FIXED, - ), - Attribute::new( - Attributes::HwVer as u16, - AttrValue::Uint16(cfg.hw_ver), - Access::RV, - Quality::FIXED, - ), - Attribute::new( - Attributes::SwVer as u16, - AttrValue::Uint32(cfg.sw_ver), - Access::RV, - Quality::FIXED, - ), - Attribute::new( - Attributes::SwVerString as u16, - AttrValue::Utf8(cfg.sw_ver_str), - Access::RV, - Quality::FIXED, - ), - Attribute::new( - Attributes::SerialNo as u16, - AttrValue::Utf8(cfg.serial_no), - Access::RV, - Quality::FIXED, - ), - ]; - cluster.base.add_attributes(&attrs[..])?; +impl<'a> BasicInfoCluster<'a> { + pub fn new(cfg: &'a BasicInfoConfig<'a>, rand: Rand) -> Self { + Self { + data_ver: Dataver::new(rand), + cfg, + } + } - Ok(cluster) + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + match attr.attr_id.try_into()? { + Attributes::DMRevision(codec) => codec.encode(writer, 1), + Attributes::VendorId(codec) => codec.encode(writer, self.cfg.vid), + Attributes::ProductId(codec) => codec.encode(writer, self.cfg.pid), + Attributes::HwVer(codec) => codec.encode(writer, self.cfg.hw_ver), + Attributes::SwVer(codec) => codec.encode(writer, self.cfg.sw_ver), + Attributes::SwVerString(codec) => codec.encode(writer, self.cfg.sw_ver_str), + Attributes::SerialNo(codec) => codec.encode(writer, self.cfg.serial_no), + } + } + } else { + Ok(()) + } } } -impl ClusterType for BasicInfoCluster { - fn base(&self) -> &Cluster { - &self.base +impl<'a> Handler for BasicInfoCluster<'a> { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + BasicInfoCluster::read(self, attr, encoder) } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base +} + +impl<'a> NonBlockingHandler for BasicInfoCluster<'a> {} + +impl<'a> ChangeNotifier<()> for BasicInfoCluster<'a> { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) } } diff --git a/matter/src/data_model/cluster_on_off.rs b/matter/src/data_model/cluster_on_off.rs index 6864a686..9b173673 100644 --- a/matter/src/data_model/cluster_on_off.rs +++ b/matter/src/data_model/cluster_on_off.rs @@ -15,114 +15,153 @@ * limitations under the License. */ +use core::convert::TryInto; + use super::objects::*; use crate::{ - cmd_enter, - error::*, - interaction_model::{command::CommandReq, core::IMStatusCode}, + attribute_enum, cmd_enter, command_enum, error::Error, interaction_model::core::Transaction, + tlv::TLVElement, utils::rand::Rand, }; use log::info; -use num_derive::FromPrimitive; +use strum::{EnumDiscriminants, FromRepr}; pub const ID: u32 = 0x0006; +#[derive(FromRepr, EnumDiscriminants)] +#[repr(u16)] pub enum Attributes { - OnOff = 0x0, + OnOff(AttrType) = 0x0, } -#[derive(FromPrimitive)] +attribute_enum!(Attributes); + +#[derive(FromRepr, EnumDiscriminants)] +#[repr(u32)] pub enum Commands { Off = 0x0, On = 0x01, Toggle = 0x02, } -fn attr_on_off_new() -> Attribute { - // OnOff, Value: false - Attribute::new( - Attributes::OnOff as u16, - AttrValue::Bool(false), - Access::RV, - Quality::PERSISTENT, - ) -} +command_enum!(Commands); + +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID as _, + feature_map: 0, + attributes: &[ + FEATURE_MAP, + ATTRIBUTE_LIST, + Attribute::new( + AttributesDiscriminants::OnOff as u16, + Access::RV, + Quality::PERSISTENT, + ), + ], + commands: &[ + CommandsDiscriminants::Off as _, + CommandsDiscriminants::On as _, + CommandsDiscriminants::Toggle as _, + ], +}; pub struct OnOffCluster { - base: Cluster, + data_ver: Dataver, + on: bool, } impl OnOffCluster { - pub fn new() -> Result, Error> { - let mut cluster = Box::new(OnOffCluster { - base: Cluster::new(ID)?, - }); - cluster.base.add_attribute(attr_on_off_new())?; - Ok(cluster) + pub fn new(rand: Rand) -> Self { + Self { + data_ver: Dataver::new(rand), + on: false, + } } -} -impl ClusterType for OnOffCluster { - fn base(&self) -> &Cluster { - &self.base + pub fn set(&mut self, on: bool) { + if self.on != on { + self.on = on; + self.data_ver.changed(); + } } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base + + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + match attr.attr_id.try_into()? { + Attributes::OnOff(codec) => codec.encode(writer, self.on), + } + } + } else { + Ok(()) + } + } + + pub fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + let data = data.with_dataver(self.data_ver.get())?; + + match attr.attr_id.try_into()? { + Attributes::OnOff(codec) => self.set(codec.decode(data)?), + } + + self.data_ver.changed(); + + Ok(()) } - fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { - let cmd = cmd_req - .cmd - .path - .leaf - .map(num::FromPrimitive::from_u32) - .ok_or(IMStatusCode::UnsupportedCommand)? - .ok_or(IMStatusCode::UnsupportedCommand)?; - match cmd { + pub fn invoke( + &mut self, + cmd: &CmdDetails, + _data: &TLVElement, + _encoder: CmdDataEncoder, + ) -> Result<(), Error> { + match cmd.cmd_id.try_into()? { Commands::Off => { cmd_enter!("Off"); - let value = self - .base - .read_attribute_raw(Attributes::OnOff as u16) - .unwrap(); - if AttrValue::Bool(true) == *value { - self.base - .write_attribute_raw(Attributes::OnOff as u16, AttrValue::Bool(false)) - .map_err(|_| IMStatusCode::Failure)?; - } - cmd_req.trans.complete(); - Err(IMStatusCode::Success) + self.set(false); } Commands::On => { cmd_enter!("On"); - let value = self - .base - .read_attribute_raw(Attributes::OnOff as u16) - .unwrap(); - if AttrValue::Bool(false) == *value { - self.base - .write_attribute_raw(Attributes::OnOff as u16, AttrValue::Bool(true)) - .map_err(|_| IMStatusCode::Failure)?; - } - - cmd_req.trans.complete(); - Err(IMStatusCode::Success) + self.set(true); } Commands::Toggle => { cmd_enter!("Toggle"); - let value = match self - .base - .read_attribute_raw(Attributes::OnOff as u16) - .unwrap() - { - &AttrValue::Bool(v) => v, - _ => false, - }; - self.base - .write_attribute_raw(Attributes::OnOff as u16, AttrValue::Bool(!value)) - .map_err(|_| IMStatusCode::Failure)?; - cmd_req.trans.complete(); - Err(IMStatusCode::Success) + self.set(!self.on); } } + + self.data_ver.changed(); + + Ok(()) + } +} + +impl Handler for OnOffCluster { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + OnOffCluster::read(self, attr, encoder) + } + + fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + OnOffCluster::write(self, attr, data) + } + + fn invoke( + &mut self, + _transaction: &mut Transaction, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + OnOffCluster::invoke(self, cmd, data, encoder) + } +} + +// TODO: Might be removed once the `on` member is externalized +impl NonBlockingHandler for OnOffCluster {} + +impl ChangeNotifier<()> for OnOffCluster { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) } } diff --git a/matter/src/data_model/cluster_template.rs b/matter/src/data_model/cluster_template.rs index 6555e9e8..c103812f 100644 --- a/matter/src/data_model/cluster_template.rs +++ b/matter/src/data_model/cluster_template.rs @@ -16,29 +16,59 @@ */ use crate::{ - data_model::objects::{Cluster, ClusterType}, + data_model::objects::{Cluster, Handler}, error::Error, + utils::rand::Rand, +}; + +use super::objects::{ + AttrDataEncoder, AttrDetails, ChangeNotifier, Dataver, NonBlockingHandler, ATTRIBUTE_LIST, + FEATURE_MAP, }; const CLUSTER_NETWORK_COMMISSIONING_ID: u32 = 0x0031; +pub const CLUSTER: Cluster<'static> = Cluster { + id: CLUSTER_NETWORK_COMMISSIONING_ID as _, + feature_map: 0, + attributes: &[FEATURE_MAP, ATTRIBUTE_LIST], + commands: &[], +}; + pub struct TemplateCluster { - base: Cluster, + data_ver: Dataver, } -impl ClusterType for TemplateCluster { - fn base(&self) -> &Cluster { - &self.base +impl TemplateCluster { + pub fn new(rand: Rand) -> Self { + Self { + data_ver: Dataver::new(rand), + } } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base + + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + Err(Error::AttributeNotFound) + } + } else { + Ok(()) + } } } -impl TemplateCluster { - pub fn new() -> Result, Error> { - Ok(Box::new(Self { - base: Cluster::new(CLUSTER_NETWORK_COMMISSIONING_ID)?, - })) +impl Handler for TemplateCluster { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + TemplateCluster::read(self, attr, encoder) + } +} + +impl NonBlockingHandler for TemplateCluster {} + +impl ChangeNotifier<()> for TemplateCluster { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) } } diff --git a/matter/src/data_model/core.rs b/matter/src/data_model/core.rs new file mode 100644 index 00000000..005871d5 --- /dev/null +++ b/matter/src/data_model/core.rs @@ -0,0 +1,199 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use core::cell::RefCell; + +use super::objects::*; +use crate::{ + acl::{Accessor, AclMgr}, + error::*, + interaction_model::core::{Interaction, Transaction}, + tlv::TLVWriter, + transport::packet::Packet, +}; + +pub struct DataModel<'a, T> { + pub acl_mgr: &'a RefCell, + pub node: &'a Node<'a>, + pub handler: T, +} + +impl<'a, T> DataModel<'a, T> { + pub const fn new(acl_mgr: &'a RefCell, node: &'a Node<'a>, handler: T) -> Self { + Self { + acl_mgr, + node, + handler, + } + } + + pub fn handle( + &mut self, + interaction: &Interaction, + tx: &mut Packet, + transaction: &mut Transaction, + ) -> Result + where + T: Handler, + { + let accessor = Accessor::for_session(transaction.session(), self.acl_mgr); + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + match interaction { + Interaction::Read(req) => { + for item in self.node.read(req, &accessor) { + AttrDataEncoder::handle_read(item, &self.handler, &mut tw)?; + } + } + Interaction::Write(req) => { + for item in self.node.write(req, &accessor) { + AttrDataEncoder::handle_write(item, &mut self.handler, &mut tw)?; + } + } + Interaction::Invoke(req) => { + for item in self.node.invoke(req, &accessor) { + CmdDataEncoder::handle(item, &mut self.handler, transaction, &mut tw)?; + } + } + Interaction::Timed(_) => (), + } + + interaction.complete_tx(tx, transaction) + } + + #[cfg(feature = "nightly")] + pub async fn handle_async<'p>( + &mut self, + interaction: &Interaction<'_>, + tx: &'p mut Packet<'_>, + transaction: &mut Transaction<'_, '_>, + ) -> Result, Error> + where + T: super::objects::asynch::AsyncHandler, + { + let accessor = Accessor::for_session(transaction.session(), self.acl_mgr); + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + match interaction { + Interaction::Read(req) => { + for item in self.node.read(req, &accessor) { + AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await?; + } + } + Interaction::Write(req) => { + for item in self.node.write(req, &accessor) { + AttrDataEncoder::handle_write_async(item, &mut self.handler, &mut tw).await?; + } + } + Interaction::Invoke(req) => { + for item in self.node.invoke(req, &accessor) { + CmdDataEncoder::handle_async(item, &mut self.handler, transaction, &mut tw) + .await?; + } + } + Interaction::Timed(_) => (), + } + + interaction.complete_tx(tx, transaction) + } +} + +pub trait DataHandler { + fn handle( + &mut self, + interaction: &Interaction, + tx: &mut Packet, + transaction: &mut Transaction, + ) -> Result; +} + +impl DataHandler for &mut T +where + T: DataHandler, +{ + fn handle( + &mut self, + interaction: &Interaction, + tx: &mut Packet, + transaction: &mut Transaction, + ) -> Result { + (**self).handle(interaction, tx, transaction) + } +} + +impl<'a, T> DataHandler for DataModel<'a, T> +where + T: Handler, +{ + fn handle( + &mut self, + interaction: &Interaction, + tx: &mut Packet, + transaction: &mut Transaction, + ) -> Result { + DataModel::handle(self, interaction, tx, transaction) + } +} + +#[cfg(feature = "nightly")] +pub mod asynch { + use crate::{ + data_model::objects::asynch::AsyncHandler, + error::Error, + interaction_model::core::{Interaction, Transaction}, + transport::packet::Packet, + }; + + use super::DataModel; + + pub trait AsyncDataHandler { + async fn handle<'p>( + &mut self, + interaction: &Interaction, + tx: &'p mut Packet, + transaction: &mut Transaction, + ) -> Result, Error>; + } + + impl AsyncDataHandler for &mut T + where + T: AsyncDataHandler, + { + async fn handle<'p>( + &mut self, + interaction: &Interaction<'_>, + tx: &'p mut Packet<'_>, + transaction: &mut Transaction<'_, '_>, + ) -> Result, Error> { + (**self).handle(interaction, tx, transaction).await + } + } + + impl<'a, T> AsyncDataHandler for DataModel<'a, T> + where + T: AsyncHandler, + { + async fn handle<'p>( + &mut self, + interaction: &Interaction<'_>, + tx: &'p mut Packet<'_>, + transaction: &mut Transaction<'_, '_>, + ) -> Result, Error> { + DataModel::handle_async(self, interaction, tx, transaction).await + } + } +} diff --git a/matter/src/data_model/core/mod.rs b/matter/src/data_model/core/mod.rs deleted file mode 100644 index 4386cab6..00000000 --- a/matter/src/data_model/core/mod.rs +++ /dev/null @@ -1,394 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use self::subscribe::SubsCtx; - -use super::{ - cluster_basic_information::BasicInfoConfig, - device_types::device_type_add_root_node, - objects::{self, *}, - sdm::dev_att::DevAttDataFetcher, - system_model::descriptor::DescriptorCluster, -}; -use crate::{ - acl::{AccessReq, Accessor, AccessorSubjects, AclMgr, AuthMode}, - error::*, - fabric::FabricMgr, - interaction_model::{ - command::CommandReq, - core::{IMStatusCode, OpCode}, - messages::{ - ib::{self, AttrData, DataVersionFilter}, - msg::{self, InvReq, ReadReq, WriteReq}, - GenericPath, - }, - InteractionConsumer, Transaction, - }, - secure_channel::pake::PaseMgr, - tlv::{self, FromTLV, TLVArray, TLVWriter, TagType, ToTLV}, - transport::{ - proto_demux::ResponseRequired, - session::{Session, SessionMode}, - }, -}; -use log::{error, info}; -use std::sync::{Arc, RwLock}; - -#[derive(Clone)] -pub struct DataModel { - pub node: Arc>>, - acl_mgr: Arc, -} - -impl DataModel { - pub fn new( - dev_details: BasicInfoConfig, - dev_att: Box, - fabric_mgr: Arc, - acl_mgr: Arc, - pase_mgr: PaseMgr, - ) -> Result { - let dm = DataModel { - node: Arc::new(RwLock::new(Node::new()?)), - acl_mgr: acl_mgr.clone(), - }; - { - let mut node = dm.node.write()?; - node.set_changes_cb(Box::new(dm.clone())); - device_type_add_root_node( - &mut node, - dev_details, - dev_att, - fabric_mgr, - acl_mgr, - pase_mgr, - )?; - } - Ok(dm) - } - - // Encode a write attribute from a path that may or may not be wildcard - fn handle_write_attr_path( - node: &mut Node, - accessor: &Accessor, - attr_data: &AttrData, - tw: &mut TLVWriter, - ) { - let gen_path = attr_data.path.to_gp(); - let mut encoder = AttrWriteEncoder::new(tw, TagType::Anonymous); - encoder.set_path(gen_path); - - // The unsupported pieces of the wildcard path - if attr_data.path.cluster.is_none() { - encoder.encode_status(IMStatusCode::UnsupportedCluster, 0); - return; - } - if attr_data.path.attr.is_none() { - encoder.encode_status(IMStatusCode::UnsupportedAttribute, 0); - return; - } - - // Get the data - let write_data = match &attr_data.data { - EncodeValue::Closure(_) | EncodeValue::Value(_) => { - error!("Not supported"); - return; - } - EncodeValue::Tlv(t) => t, - }; - - if gen_path.is_wildcard() { - // This is a wildcard path, skip error - // This is required because there could be access control errors too that need - // to be taken care of. - encoder.skip_error(); - } - let mut attr = AttrDetails { - // will be udpated in the loop below - attr_id: 0, - list_index: attr_data.path.list_index, - fab_filter: false, - fab_idx: accessor.fab_idx, - }; - - let result = node.for_each_cluster_mut(&gen_path, |path, c| { - if attr_data.data_ver.is_some() && Some(c.base().get_dataver()) != attr_data.data_ver { - encoder.encode_status(IMStatusCode::DataVersionMismatch, 0); - return Ok(()); - } - - attr.attr_id = path.leaf.unwrap_or_default() as u16; - encoder.set_path(*path); - let mut access_req = AccessReq::new(accessor, path, Access::WRITE); - let r = match Cluster::write_attribute(c, &mut access_req, write_data, &attr) { - Ok(_) => IMStatusCode::Success, - Err(e) => e, - }; - encoder.encode_status(r, 0); - Ok(()) - }); - if let Err(e) = result { - // We hit this only if this is a non-wildcard path and some parts of the path are missing - encoder.encode_status(e, 0); - } - } - - // Handle command from a path that may or may not be wildcard - fn handle_command_path(node: &mut Node, cmd_req: &mut CommandReq) { - let wildcard = cmd_req.cmd.path.is_wildcard(); - let path = cmd_req.cmd.path; - - let result = node.for_each_cluster_mut(&path, |path, c| { - cmd_req.cmd.path = *path; - let result = c.handle_command(cmd_req); - if let Err(e) = result { - // It is likely that we might have to do an 'Access' aware traversal - // if there are other conditions in the wildcard scenario that shouldn't be - // encoded as CmdStatus - if !(wildcard && e == IMStatusCode::UnsupportedCommand) { - let invoke_resp = ib::InvResp::status_new(cmd_req.cmd, e, 0); - let _ = invoke_resp.to_tlv(cmd_req.resp, TagType::Anonymous); - } - } - Ok(()) - }); - if !wildcard { - if let Err(e) = result { - // We hit this only if this is a non-wildcard path - let invoke_resp = ib::InvResp::status_new(cmd_req.cmd, e, 0); - let _ = invoke_resp.to_tlv(cmd_req.resp, TagType::Anonymous); - } - } - } - - fn sess_to_accessor(&self, sess: &Session) -> Accessor { - match sess.get_session_mode() { - SessionMode::Case(c) => { - let mut subject = - AccessorSubjects::new(sess.get_peer_node_id().unwrap_or_default()); - for i in c.cat_ids { - if i != 0 { - let _ = subject.add_catid(i); - } - } - Accessor::new(c.fab_idx, subject, AuthMode::Case, self.acl_mgr.clone()) - } - SessionMode::Pase => Accessor::new( - 0, - AccessorSubjects::new(1), - AuthMode::Pase, - self.acl_mgr.clone(), - ), - - SessionMode::PlainText => Accessor::new( - 0, - AccessorSubjects::new(1), - AuthMode::Invalid, - self.acl_mgr.clone(), - ), - } - } - - /// Returns true if the path matches the cluster path and the data version is a match - fn data_filter_matches( - filters: &Option<&TLVArray>, - path: &GenericPath, - data_ver: u32, - ) -> bool { - if let Some(filters) = *filters { - for filter in filters.iter() { - // TODO: No handling of 'node' comparision yet - if Some(filter.path.endpoint) == path.endpoint - && Some(filter.path.cluster) == path.cluster - && filter.data_ver == data_ver - { - return true; - } - } - } - false - } -} - -pub mod read; -pub mod subscribe; - -/// Type of Resume Request -enum ResumeReq { - Subscribe(subscribe::SubsCtx), - Read(read::ResumeReadReq), -} - -impl objects::ChangeConsumer for DataModel { - fn endpoint_added(&self, id: EndptId, endpoint: &mut Endpoint) -> Result<(), Error> { - endpoint.add_cluster(DescriptorCluster::new(id, self.clone())?)?; - Ok(()) - } -} - -impl InteractionConsumer for DataModel { - fn consume_write_attr( - &self, - write_req: &WriteReq, - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(), Error> { - let accessor = self.sess_to_accessor(trans.session); - - tw.start_array(TagType::Context(msg::WriteRespTag::WriteResponses as u8))?; - let mut node = self.node.write().unwrap(); - for attr_data in write_req.write_requests.iter() { - DataModel::handle_write_attr_path(&mut node, &accessor, &attr_data, tw); - } - tw.end_container()?; - - Ok(()) - } - - fn consume_read_attr( - &self, - rx_buf: &[u8], - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(), Error> { - let mut resume_from = None; - let root = tlv::get_root_node(rx_buf)?; - let req = ReadReq::from_tlv(&root)?; - self.handle_read_req(&req, trans, tw, &mut resume_from)?; - if resume_from.is_some() { - // This is a multi-hop read transaction, remember this read request - let resume = read::ResumeReadReq::new(rx_buf, &resume_from)?; - if !trans.exch.is_data_none() { - error!("Exchange data already set, and multi-hop read"); - return Err(Error::InvalidState); - } - trans.exch.set_data_boxed(Box::new(ResumeReq::Read(resume))); - } - Ok(()) - } - - fn consume_invoke_cmd( - &self, - inv_req_msg: &InvReq, - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(), Error> { - let mut node = self.node.write().unwrap(); - if let Some(inv_requests) = &inv_req_msg.inv_requests { - // Array of InvokeResponse IBs - tw.start_array(TagType::Context(msg::InvRespTag::InvokeResponses as u8))?; - for i in inv_requests.iter() { - let data = if let Some(data) = i.data.unwrap_tlv() { - data - } else { - continue; - }; - info!("Invoke Commmand Handler executing: {:?}", i.path); - let mut cmd_req = CommandReq { - cmd: i.path, - data, - trans, - resp: tw, - }; - DataModel::handle_command_path(&mut node, &mut cmd_req); - } - tw.end_container()?; - } - - Ok(()) - } - - fn consume_status_report( - &self, - req: &msg::StatusResp, - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(OpCode, ResponseRequired), Error> { - if let Some(mut resume) = trans.exch.take_data_boxed::() { - let result = match *resume { - ResumeReq::Read(ref mut read) => self.handle_resume_read(read, trans, tw)?, - - ResumeReq::Subscribe(ref mut ctx) => ctx.handle_status_report(trans, tw, self)?, - }; - trans.exch.set_data_boxed(resume); - Ok(result) - } else { - // Nothing to do for now - trans.complete(); - info!("Received status report with status {:?}", req.status); - Ok((OpCode::Reserved, ResponseRequired::No)) - } - } - - fn consume_subscribe( - &self, - rx_buf: &[u8], - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(OpCode, ResponseRequired), Error> { - if !trans.exch.is_data_none() { - error!("Exchange data already set!"); - return Err(Error::InvalidState); - } - let ctx = SubsCtx::new(rx_buf, trans, tw, self)?; - trans - .exch - .set_data_boxed(Box::new(ResumeReq::Subscribe(ctx))); - Ok((OpCode::ReportData, ResponseRequired::Yes)) - } -} - -/// Encoder for generating a response to a write request -pub struct AttrWriteEncoder<'a, 'b, 'c> { - tw: &'a mut TLVWriter<'b, 'c>, - tag: TagType, - path: GenericPath, - skip_error: bool, -} -impl<'a, 'b, 'c> AttrWriteEncoder<'a, 'b, 'c> { - pub fn new(tw: &'a mut TLVWriter<'b, 'c>, tag: TagType) -> Self { - Self { - tw, - tag, - path: Default::default(), - skip_error: false, - } - } - - pub fn skip_error(&mut self) { - self.skip_error = true; - } - - pub fn set_path(&mut self, path: GenericPath) { - self.path = path; - } -} - -impl<'a, 'b, 'c> Encoder for AttrWriteEncoder<'a, 'b, 'c> { - fn encode(&mut self, _value: EncodeValue) { - // Only status encodes for AttrWriteResponse - } - - fn encode_status(&mut self, status: IMStatusCode, cluster_status: u16) { - if self.skip_error && status != IMStatusCode::Success { - // Don't encode errors - return; - } - let resp = ib::AttrStatus::new(&self.path, status, cluster_status); - let _ = resp.to_tlv(self.tw, self.tag); - } -} diff --git a/matter/src/data_model/core/read.rs b/matter/src/data_model/core/read.rs deleted file mode 100644 index 07eb1a36..00000000 --- a/matter/src/data_model/core/read.rs +++ /dev/null @@ -1,319 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use crate::{ - acl::{AccessReq, Accessor}, - data_model::{core::DataModel, objects::*}, - error::*, - interaction_model::{ - core::{IMStatusCode, OpCode}, - messages::{ - ib::{self, DataVersionFilter}, - msg::{self, ReadReq, ReportDataTag::MoreChunkedMsgs, ReportDataTag::SupressResponse}, - GenericPath, - }, - Transaction, - }, - tlv::{self, FromTLV, TLVArray, TLVWriter, TagType, ToTLV}, - transport::{packet::Packet, proto_demux::ResponseRequired}, - utils::writebuf::WriteBuf, - wb_shrink, wb_unshrink, -}; -use log::error; - -/// Encoder for generating a response to a read request -pub struct AttrReadEncoder<'a, 'b, 'c> { - tw: &'a mut TLVWriter<'b, 'c>, - data_ver: u32, - path: GenericPath, - skip_error: bool, - data_ver_filters: Option<&'a TLVArray<'a, DataVersionFilter>>, - is_buffer_full: bool, -} - -impl<'a, 'b, 'c> AttrReadEncoder<'a, 'b, 'c> { - pub fn new(tw: &'a mut TLVWriter<'b, 'c>) -> Self { - Self { - tw, - data_ver: 0, - skip_error: false, - path: Default::default(), - data_ver_filters: None, - is_buffer_full: false, - } - } - - pub fn skip_error(&mut self, skip: bool) { - self.skip_error = skip; - } - - pub fn set_data_ver(&mut self, data_ver: u32) { - self.data_ver = data_ver; - } - - pub fn set_data_ver_filters(&mut self, filters: &'a TLVArray<'a, DataVersionFilter>) { - self.data_ver_filters = Some(filters); - } - - pub fn set_path(&mut self, path: GenericPath) { - self.path = path; - } - - pub fn is_buffer_full(&self) -> bool { - self.is_buffer_full - } -} - -impl<'a, 'b, 'c> Encoder for AttrReadEncoder<'a, 'b, 'c> { - fn encode(&mut self, value: EncodeValue) { - let resp = ib::AttrResp::Data(ib::AttrData::new( - Some(self.data_ver), - ib::AttrPath::new(&self.path), - value, - )); - - let anchor = self.tw.get_tail(); - if resp.to_tlv(self.tw, TagType::Anonymous).is_err() { - self.is_buffer_full = true; - self.tw.rewind_to(anchor); - } - } - - fn encode_status(&mut self, status: IMStatusCode, cluster_status: u16) { - if !self.skip_error { - let resp = - ib::AttrResp::Status(ib::AttrStatus::new(&self.path, status, cluster_status)); - let _ = resp.to_tlv(self.tw, TagType::Anonymous); - } - } -} - -/// State to maintain when a Read Request needs to be resumed -/// resumed - the next chunk of the read needs to be returned -#[derive(Default)] -pub struct ResumeReadReq { - /// The Read Request Attribute Path that caused chunking, and this is the path - /// that needs to be resumed. - pub pending_req: Option>, - - /// The Attribute that couldn't be encoded because our buffer got full. The next chunk - /// will start encoding from this attribute onwards. - /// Note that given wildcard reads, one PendingPath in the member above can generated - /// multiple encode paths. Hence this has to be maintained separately. - pub resume_from: Option, -} -impl ResumeReadReq { - pub fn new(rx_buf: &[u8], resume_from: &Option) -> Result { - let mut packet = Packet::new_rx()?; - let dst = packet.as_borrow_slice(); - - let src_len = rx_buf.len(); - dst[..src_len].copy_from_slice(rx_buf); - packet.get_parsebuf()?.set_len(src_len); - Ok(ResumeReadReq { - pending_req: Some(packet), - resume_from: *resume_from, - }) - } -} - -impl DataModel { - pub fn read_attribute_raw( - &self, - endpoint: EndptId, - cluster: ClusterId, - attr: AttrId, - ) -> Result { - let node = self.node.read().unwrap(); - let cluster = node.get_cluster(endpoint, cluster)?; - cluster.base().read_attribute_raw(attr).map(|a| a.clone()) - } - /// Encode a read attribute from a path that may or may not be wildcard - /// - /// If the buffer gets full while generating the read response, we will return - /// an Err(path), where the path is the path that we should resume from, for the next chunk. - /// This facilitates chunk management - fn handle_read_attr_path( - node: &Node, - accessor: &Accessor, - attr_encoder: &mut AttrReadEncoder, - attr_details: &mut AttrDetails, - resume_from: &mut Option, - ) -> Result<(), Error> { - let mut status = Ok(()); - let path = attr_encoder.path; - - // Skip error reporting for wildcard paths, don't for concrete paths - attr_encoder.skip_error(path.is_wildcard()); - - let result = node.for_each_attribute(&path, |path, c| { - // Ignore processing if data filter matches. - // For a wildcard attribute, this may end happening unnecessarily for all attributes, although - // a single skip for the cluster is sufficient. That requires us to replace this for_each with a - // for_each_cluster - let cluster_data_ver = c.base().get_dataver(); - if Self::data_filter_matches(&attr_encoder.data_ver_filters, path, cluster_data_ver) { - return Ok(()); - } - - // The resume_from indicates that this is the next chunk of a previous Read Request. In such cases, we - // need to skip until we hit this path. - if let Some(r) = resume_from { - // If resume_from is valid, and we haven't hit the resume_from yet, skip encoding - if r != path { - return Ok(()); - } else { - // Else, wipe out the resume_from so subsequent paths can be encoded - *resume_from = None; - } - } - - attr_details.attr_id = path.leaf.unwrap_or_default() as u16; - // Overwrite the previous path with the concrete path - attr_encoder.set_path(*path); - // Set the cluster's data version - attr_encoder.set_data_ver(cluster_data_ver); - let mut access_req = AccessReq::new(accessor, path, Access::READ); - Cluster::read_attribute(c, &mut access_req, attr_encoder, attr_details); - if attr_encoder.is_buffer_full() { - // Buffer is full, next time resume from this attribute - *resume_from = Some(*path); - status = Err(Error::NoSpace); - } - Ok(()) - }); - if let Err(e) = result { - // We hit this only if this is a non-wildcard path - attr_encoder.encode_status(e, 0); - } - status - } - - /// Process an array of Attribute Read Requests - /// - /// When the API returns the chunked read is on, if *resume_from is Some(x) otherwise - /// the read is complete - pub(super) fn handle_read_attr_array( - &self, - read_req: &ReadReq, - trans: &mut Transaction, - old_tw: &mut TLVWriter, - resume_from: &mut Option, - ) -> Result<(), Error> { - let old_wb = old_tw.get_buf(); - // Note, this function may be called from multiple places: a) an actual read - // request, a b) resumed read request, c) subscribe request or d) resumed subscribe - // request. Hopefully 18 is sufficient to address all those scenarios. - // - // This is the amount of space we reserve for other things to be attached towards - // the end - const RESERVE_SIZE: usize = 24; - let mut new_wb = wb_shrink!(old_wb, RESERVE_SIZE); - let mut tw = TLVWriter::new(&mut new_wb); - - let mut attr_encoder = AttrReadEncoder::new(&mut tw); - if let Some(filters) = &read_req.dataver_filters { - attr_encoder.set_data_ver_filters(filters); - } - - if let Some(attr_requests) = &read_req.attr_requests { - let accessor = self.sess_to_accessor(trans.session); - let mut attr_details = AttrDetails::new(accessor.fab_idx, read_req.fabric_filtered); - let node = self.node.read().unwrap(); - attr_encoder - .tw - .start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; - - let mut result = Ok(()); - for attr_path in attr_requests.iter() { - attr_encoder.set_path(attr_path.to_gp()); - // Extract the attr_path fields into various structures - attr_details.list_index = attr_path.list_index; - result = DataModel::handle_read_attr_path( - &node, - &accessor, - &mut attr_encoder, - &mut attr_details, - resume_from, - ); - if result.is_err() { - break; - } - } - // Now that all the read reports are captured, let's use the old_tw that is - // the full writebuf, and hopefully as all the necessary space to store this - wb_unshrink!(old_wb, new_wb); - old_tw.end_container()?; // Finish the AttrReports - - if result.is_err() { - // If there was an error, indicate chunking. The resume_read_req would have been - // already populated in the loop above. - old_tw.bool(TagType::Context(MoreChunkedMsgs as u8), true)?; - } else { - // A None resume_from indicates no chunking - *resume_from = None; - } - } - Ok(()) - } - - /// Handle a read request - /// - /// This could be called from an actual read request or a resumed read request. Subscription - /// requests do not come to this function. - /// When the API returns the chunked read is on, if *resume_from is Some(x) otherwise - /// the read is complete - pub fn handle_read_req( - &self, - read_req: &ReadReq, - trans: &mut Transaction, - tw: &mut TLVWriter, - resume_from: &mut Option, - ) -> Result<(OpCode, ResponseRequired), Error> { - tw.start_struct(TagType::Anonymous)?; - - self.handle_read_attr_array(read_req, trans, tw, resume_from)?; - - if resume_from.is_none() { - tw.bool(TagType::Context(SupressResponse as u8), true)?; - // Mark transaction complete, if not chunked - trans.complete(); - } - tw.end_container()?; - Ok((OpCode::ReportData, ResponseRequired::Yes)) - } - - /// Handle a resumed read request - pub fn handle_resume_read( - &self, - resume_read_req: &mut ResumeReadReq, - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(OpCode, ResponseRequired), Error> { - if let Some(packet) = resume_read_req.pending_req.as_mut() { - let rx_buf = packet.get_parsebuf()?.as_borrow_slice(); - let root = tlv::get_root_node(rx_buf)?; - let req = ReadReq::from_tlv(&root)?; - - self.handle_read_req(&req, trans, tw, &mut resume_read_req.resume_from) - } else { - // No pending req, is that even possible? - error!("This shouldn't have happened"); - Ok((OpCode::Reserved, ResponseRequired::No)) - } - } -} diff --git a/matter/src/data_model/core/subscribe.rs b/matter/src/data_model/core/subscribe.rs deleted file mode 100644 index a65ee1fd..00000000 --- a/matter/src/data_model/core/subscribe.rs +++ /dev/null @@ -1,142 +0,0 @@ -/* - * - * Copyright (c) 2023 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use std::sync::atomic::{AtomicU32, Ordering}; - -use crate::{ - error::Error, - interaction_model::{ - core::OpCode, - messages::{ - msg::{self, SubscribeReq, SubscribeResp}, - GenericPath, - }, - }, - tlv::{self, get_root_node_struct, FromTLV, TLVWriter, TagType, ToTLV}, - transport::proto_demux::ResponseRequired, -}; - -use super::{read::ResumeReadReq, DataModel, Transaction}; - -static SUBS_ID: AtomicU32 = AtomicU32::new(1); - -#[derive(PartialEq)] -enum SubsState { - Confirming, - Confirmed, -} - -pub struct SubsCtx { - state: SubsState, - id: u32, - resume_read_req: Option, -} - -impl SubsCtx { - pub fn new( - rx_buf: &[u8], - trans: &mut Transaction, - tw: &mut TLVWriter, - dm: &DataModel, - ) -> Result { - let root = get_root_node_struct(rx_buf)?; - let req = SubscribeReq::from_tlv(&root)?; - - let mut ctx = SubsCtx { - state: SubsState::Confirming, - // TODO - id: SUBS_ID.fetch_add(1, Ordering::SeqCst), - resume_read_req: None, - }; - - let mut resume_from = None; - ctx.do_read(&req, trans, tw, dm, &mut resume_from)?; - if resume_from.is_some() { - // This is a multi-hop read transaction, remember this read request - ctx.resume_read_req = Some(ResumeReadReq::new(rx_buf, &resume_from)?); - } - Ok(ctx) - } - - pub fn handle_status_report( - &mut self, - trans: &mut Transaction, - tw: &mut TLVWriter, - dm: &DataModel, - ) -> Result<(OpCode, ResponseRequired), Error> { - if self.state != SubsState::Confirming { - // Not relevant for us - trans.complete(); - return Err(Error::Invalid); - } - - // Is there a previous resume read pending - if self.resume_read_req.is_some() { - let mut resume_read_req = self.resume_read_req.take().unwrap(); - if let Some(packet) = resume_read_req.pending_req.as_mut() { - let rx_buf = packet.get_parsebuf()?.as_borrow_slice(); - let root = tlv::get_root_node(rx_buf)?; - let req = SubscribeReq::from_tlv(&root)?; - - self.do_read(&req, trans, tw, dm, &mut resume_read_req.resume_from)?; - if resume_read_req.resume_from.is_some() { - // More chunks are pending, setup resume_read_req again - self.resume_read_req = Some(resume_read_req); - } - - return Ok((OpCode::ReportData, ResponseRequired::Yes)); - } - } - - // We are here implies that the read is now complete - self.confirm_subscription(trans, tw) - } - - fn confirm_subscription( - &mut self, - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(OpCode, ResponseRequired), Error> { - self.state = SubsState::Confirmed; - - // TODO - let resp = SubscribeResp::new(self.id, 40); - resp.to_tlv(tw, TagType::Anonymous)?; - trans.complete(); - Ok((OpCode::SubscriptResponse, ResponseRequired::Yes)) - } - - fn do_read( - &mut self, - req: &SubscribeReq, - trans: &mut Transaction, - tw: &mut TLVWriter, - dm: &DataModel, - resume_from: &mut Option, - ) -> Result<(), Error> { - let read_req = req.to_read_req(); - tw.start_struct(TagType::Anonymous)?; - tw.u32( - TagType::Context(msg::ReportDataTag::SubscriptionId as u8), - self.id, - )?; - dm.handle_read_attr_array(&read_req, trans, tw, resume_from)?; - tw.end_container()?; - - Ok(()) - } -} diff --git a/matter/src/data_model/device_types.rs b/matter/src/data_model/device_types.rs index 9c379710..fbce4942 100644 --- a/matter/src/data_model/device_types.rs +++ b/matter/src/data_model/device_types.rs @@ -15,60 +15,14 @@ * limitations under the License. */ -use super::cluster_basic_information::BasicInfoCluster; -use super::cluster_basic_information::BasicInfoConfig; -use super::cluster_on_off::OnOffCluster; -use super::objects::*; -use super::sdm::admin_commissioning::AdminCommCluster; -use super::sdm::dev_att::DevAttDataFetcher; -use super::sdm::general_commissioning::GenCommCluster; -use super::sdm::noc::NocCluster; -use super::sdm::nw_commissioning::NwCommCluster; -use super::system_model::access_control::AccessControlCluster; -use crate::acl::AclMgr; -use crate::error::*; -use crate::fabric::FabricMgr; -use crate::secure_channel::pake::PaseMgr; -use std::sync::Arc; -use std::sync::RwLockWriteGuard; +use super::objects::DeviceType; pub const DEV_TYPE_ROOT_NODE: DeviceType = DeviceType { dtype: 0x0016, drev: 1, }; -type WriteNode<'a> = RwLockWriteGuard<'a, Box>; - -pub fn device_type_add_root_node( - node: &mut WriteNode, - dev_info: BasicInfoConfig, - dev_att: Box, - fabric_mgr: Arc, - acl_mgr: Arc, - pase_mgr: PaseMgr, -) -> Result { - // Add the root endpoint - let endpoint = node.add_endpoint(DEV_TYPE_ROOT_NODE)?; - if endpoint != 0 { - // Somehow endpoint 0 was already added, this shouldn't be the case - return Err(Error::Invalid); - }; - // Add the mandatory clusters - node.add_cluster(0, BasicInfoCluster::new(dev_info)?)?; - let general_commissioning = GenCommCluster::new()?; - let failsafe = general_commissioning.failsafe(); - node.add_cluster(0, general_commissioning)?; - node.add_cluster(0, NwCommCluster::new()?)?; - node.add_cluster(0, AdminCommCluster::new(pase_mgr)?)?; - node.add_cluster( - 0, - NocCluster::new(dev_att, fabric_mgr, acl_mgr.clone(), failsafe)?, - )?; - node.add_cluster(0, AccessControlCluster::new(acl_mgr)?)?; - Ok(endpoint) -} - -const DEV_TYPE_ON_OFF_LIGHT: DeviceType = DeviceType { +pub const DEV_TYPE_ON_OFF_LIGHT: DeviceType = DeviceType { dtype: 0x0100, drev: 2, }; @@ -77,9 +31,3 @@ pub const DEV_TYPE_ON_SMART_SPEAKER: DeviceType = DeviceType { dtype: 0x0022, drev: 2, }; - -pub fn device_type_add_on_off_light(node: &mut WriteNode) -> Result { - let endpoint = node.add_endpoint(DEV_TYPE_ON_OFF_LIGHT)?; - node.add_cluster(endpoint, OnOffCluster::new()?)?; - Ok(endpoint) -} diff --git a/matter/src/data_model/mod.rs b/matter/src/data_model/mod.rs index c3479419..c76e07cf 100644 --- a/matter/src/data_model/mod.rs +++ b/matter/src/data_model/mod.rs @@ -20,8 +20,9 @@ pub mod device_types; pub mod objects; pub mod cluster_basic_information; -pub mod cluster_media_playback; +// TODO pub mod cluster_media_playback; pub mod cluster_on_off; pub mod cluster_template; +pub mod root_endpoint; pub mod sdm; pub mod system_model; diff --git a/matter/src/data_model/objects/attribute.rs b/matter/src/data_model/objects/attribute.rs index 28be5c64..69fd5446 100644 --- a/matter/src/data_model/objects/attribute.rs +++ b/matter/src/data_model/objects/attribute.rs @@ -15,15 +15,11 @@ * limitations under the License. */ -use super::{AttrId, GlobalElements, Privilege}; -use crate::{ - error::*, - // TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer - tlv::{TLVElement, TLVWriter, TagType, ToTLV}, -}; +use crate::data_model::objects::GlobalElements; + +use super::{AttrId, Privilege}; use bitflags::bitflags; -use log::error; -use std::fmt::{self, Debug, Formatter}; +use core::fmt::{self, Debug}; bitflags! { #[derive(Default)] @@ -83,110 +79,24 @@ bitflags! { } } -/* This file needs some major revamp. - * - instead of allocating all over the heap, we should use some kind of slab/block allocator - * - instead of arrays, can use linked-lists to conserve space and avoid the internal fragmentation - */ - -#[derive(PartialEq, PartialOrd, Clone)] -pub enum AttrValue { - Int64(i64), - Uint8(u8), - Uint16(u16), - Uint32(u32), - Uint64(u64), - Bool(bool), - Utf8(String), - Custom, -} - -impl Debug for AttrValue { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { - match &self { - AttrValue::Int64(v) => write!(f, "{:?}", *v), - AttrValue::Uint8(v) => write!(f, "{:?}", *v), - AttrValue::Uint16(v) => write!(f, "{:?}", *v), - AttrValue::Uint32(v) => write!(f, "{:?}", *v), - AttrValue::Uint64(v) => write!(f, "{:?}", *v), - AttrValue::Bool(v) => write!(f, "{:?}", *v), - AttrValue::Utf8(v) => write!(f, "{:?}", *v), - AttrValue::Custom => write!(f, "custom-attribute"), - }?; - Ok(()) - } -} - -impl ToTLV for AttrValue { - fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { - // What is the time complexity of such long match statements? - match self { - AttrValue::Bool(v) => tw.bool(tag_type, *v), - AttrValue::Uint8(v) => tw.u8(tag_type, *v), - AttrValue::Uint16(v) => tw.u16(tag_type, *v), - AttrValue::Uint32(v) => tw.u32(tag_type, *v), - AttrValue::Uint64(v) => tw.u64(tag_type, *v), - AttrValue::Utf8(v) => tw.utf8(tag_type, v.as_bytes()), - _ => { - error!("Attribute type not yet supported"); - Err(Error::AttributeNotFound) - } - } - } -} - -impl AttrValue { - pub fn update_from_tlv(&mut self, tr: &TLVElement) -> Result<(), Error> { - match self { - AttrValue::Bool(v) => *v = tr.bool()?, - AttrValue::Uint8(v) => *v = tr.u8()?, - AttrValue::Uint16(v) => *v = tr.u16()?, - AttrValue::Uint32(v) => *v = tr.u32()?, - AttrValue::Uint64(v) => *v = tr.u64()?, - _ => { - error!("Attribute type not yet supported"); - return Err(Error::AttributeNotFound); - } - } - Ok(()) - } -} - #[derive(Debug, Clone)] pub struct Attribute { - pub(super) id: AttrId, - pub(super) value: AttrValue, - pub(super) quality: Quality, - pub(super) access: Access, -} - -impl Default for Attribute { - fn default() -> Attribute { - Attribute { - id: 0, - value: AttrValue::Bool(true), - quality: Default::default(), - access: Default::default(), - } - } + pub id: AttrId, + pub quality: Quality, + pub access: Access, } impl Attribute { - pub fn new(id: AttrId, value: AttrValue, access: Access, quality: Quality) -> Self { - Attribute { + pub const fn new(id: AttrId, access: Access, quality: Quality) -> Self { + Self { id, - value, access, quality, } } - pub fn set_value(&mut self, value: AttrValue) -> Result<(), Error> { - if !self.quality.contains(Quality::FIXED) { - self.value = value; - Ok(()) - } else { - Err(Error::Invalid) - } + pub fn is_system(&self) -> bool { + Self::is_system_attr(self.id) } pub fn is_system_attr(attr_id: AttrId) -> bool { @@ -194,9 +104,9 @@ impl Attribute { } } -impl std::fmt::Display for Attribute { +impl core::fmt::Display for Attribute { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}: {:?}", self.id, self.value) + write!(f, "{}", self.id) } } diff --git a/matter/src/data_model/objects/cluster.rs b/matter/src/data_model/objects/cluster.rs index 7ca8350e..90c6835d 100644 --- a/matter/src/data_model/objects/cluster.rs +++ b/matter/src/data_model/objects/cluster.rs @@ -15,25 +15,31 @@ * limitations under the License. */ +use log::error; +use strum::FromRepr; + use crate::{ - acl::AccessReq, - data_model::objects::{Access, AttrValue, Attribute, EncodeValue, Quality}, - error::*, - interaction_model::{command::CommandReq, core::IMStatusCode}, + acl::{AccessReq, Accessor}, + attribute_enum, + data_model::objects::*, + error::Error, + interaction_model::{ + core::IMStatusCode, + messages::{ + ib::{AttrPath, AttrStatus, CmdPath, CmdStatus}, + GenericPath, + }, + }, // TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer - tlv::{Nullable, TLVElement, TLVWriter, TagType}, + tlv::{Nullable, TLVWriter, TagType}, +}; +use core::{ + convert::TryInto, + fmt::{self, Debug}, }; -use log::error; -use num_derive::FromPrimitive; -use rand::Rng; -use std::fmt::{self, Debug}; - -use super::{AttrId, ClusterId, Encoder}; - -pub const ATTRS_PER_CLUSTER: usize = 10; -pub const CMDS_PER_CLUSTER: usize = 8; -#[derive(FromPrimitive, Debug)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, FromRepr)] +#[repr(u16)] pub enum GlobalElements { _ClusterRevision = 0xFFFD, FeatureMap = 0xFFFC, @@ -44,297 +50,308 @@ pub enum GlobalElements { FabricIndex = 0xFE, } +attribute_enum!(GlobalElements); + +pub const FEATURE_MAP: Attribute = + Attribute::new(GlobalElements::FeatureMap as _, Access::RV, Quality::NONE); + +pub const ATTRIBUTE_LIST: Attribute = Attribute::new( + GlobalElements::AttributeList as _, + Access::RV, + Quality::NONE, +); + // TODO: What if we instead of creating this, we just pass the AttrData/AttrPath to the read/write // methods? /// The Attribute Details structure records the details about the attribute under consideration. -/// Typically this structure is progressively built as we proceed through the request processing. -pub struct AttrDetails { - /// Fabric Filtering Activated - pub fab_filter: bool, - /// The current Fabric Index - pub fab_idx: u8, - /// List Index, if any - pub list_index: Option>, +pub struct AttrDetails<'a> { + pub node: &'a Node<'a>, + /// The actual endpoint ID + pub endpoint_id: EndptId, + /// The actual cluster ID + pub cluster_id: ClusterId, /// The actual attribute ID pub attr_id: AttrId, + /// List Index, if any + pub list_index: Option>, + /// The current Fabric Index + pub fab_idx: u8, + /// Fabric Filtering Activated + pub fab_filter: bool, + pub dataver: Option, + pub wildcard: bool, } -impl AttrDetails { - pub fn new(fab_idx: u8, fab_filter: bool) -> Self { - Self { - fab_filter, - fab_idx, - list_index: None, - attr_id: 0, - } +impl<'a> AttrDetails<'a> { + pub fn is_system(&self) -> bool { + Attribute::is_system_attr(self.attr_id) } -} -pub trait ClusterType { - // TODO: 5 methods is going to be quite expensive for vtables of all the clusters - fn base(&self) -> &Cluster; - fn base_mut(&mut self) -> &mut Cluster; - fn read_custom_attribute(&self, _encoder: &mut dyn Encoder, _attr: &AttrDetails) {} - - fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { - let cmd = cmd_req.cmd.path.leaf.map(|a| a as u16); - println!("Received command: {:?}", cmd); + pub fn path(&self) -> AttrPath { + AttrPath { + endpoint: Some(self.endpoint_id), + cluster: Some(self.cluster_id), + attr: Some(self.attr_id), + list_index: self.list_index, + ..Default::default() + } + } - Err(IMStatusCode::UnsupportedCommand) + pub fn status(&self, status: IMStatusCode) -> Result, Error> { + if self.should_report(status) { + Ok(Some(AttrStatus::new( + &GenericPath { + endpoint: Some(self.endpoint_id), + cluster: Some(self.cluster_id), + leaf: Some(self.attr_id as _), + }, + status, + 0, + ))) + } else { + Ok(None) + } } - /// Write an attribute - /// - /// Note that if this method is defined, you must handle the write for all the attributes. Even those - /// that are not 'custom'. This is different from how you handle the read_custom_attribute() method. - /// The reason for this being, you may want to handle an attribute write request even though it is a - /// standard attribute like u16, u32 etc. - /// - /// If you wish to update the standard attribute in the data model database, you must call the - /// write_attribute_from_tlv() method from the base cluster, as is shown here in the default case - fn write_attribute( - &mut self, - attr: &AttrDetails, - data: &TLVElement, - ) -> Result<(), IMStatusCode> { - self.base_mut().write_attribute_from_tlv(attr.attr_id, data) + fn should_report(&self, status: IMStatusCode) -> bool { + !self.wildcard + || !matches!( + status, + IMStatusCode::UnsupportedEndpoint + | IMStatusCode::UnsupportedCluster + | IMStatusCode::UnsupportedAttribute + | IMStatusCode::UnsupportedCommand + | IMStatusCode::UnsupportedAccess + | IMStatusCode::UnsupportedRead + | IMStatusCode::UnsupportedWrite + | IMStatusCode::DataVersionMismatch + ) } } -pub struct Cluster { - pub(super) id: ClusterId, - attributes: Vec, - data_ver: u32, +pub struct CmdDetails<'a> { + pub node: &'a Node<'a>, + pub endpoint_id: EndptId, + pub cluster_id: ClusterId, + pub cmd_id: CmdId, + pub wildcard: bool, } -impl Cluster { - pub fn new(id: ClusterId) -> Result { - let mut c = Cluster { - id, - attributes: Vec::with_capacity(ATTRS_PER_CLUSTER), - data_ver: rand::thread_rng().gen_range(0..0xFFFFFFFF), - }; - c.add_default_attributes()?; - Ok(c) - } - - pub fn id(&self) -> ClusterId { - self.id - } - - pub fn get_dataver(&self) -> u32 { - self.data_ver - } - - pub fn set_feature_map(&mut self, map: u32) -> Result<(), Error> { - self.write_attribute_raw(GlobalElements::FeatureMap as u16, AttrValue::Uint32(map)) - .map_err(|_| Error::Invalid)?; - Ok(()) - } - - fn add_default_attributes(&mut self) -> Result<(), Error> { - // Default feature map is 0 - self.add_attribute(Attribute::new( - GlobalElements::FeatureMap as u16, - AttrValue::Uint32(0), - Access::RV, - Quality::NONE, - ))?; - - self.add_attribute(Attribute::new( - GlobalElements::AttributeList as u16, - AttrValue::Custom, - Access::RV, - Quality::NONE, - )) +impl<'a> CmdDetails<'a> { + pub fn path(&self) -> CmdPath { + CmdPath::new( + Some(self.endpoint_id), + Some(self.cluster_id), + Some(self.cmd_id), + ) } - pub fn add_attributes(&mut self, attrs: &[Attribute]) -> Result<(), Error> { - if self.attributes.len() + attrs.len() <= self.attributes.capacity() { - self.attributes.extend_from_slice(attrs); - Ok(()) + pub fn success(&self, tracker: &CmdDataTracker) -> Option { + if tracker.needs_status() { + self.status(IMStatusCode::Success) } else { - Err(Error::NoSpace) + None } } - pub fn add_attribute(&mut self, attr: Attribute) -> Result<(), Error> { - if self.attributes.len() < self.attributes.capacity() { - self.attributes.push(attr); - Ok(()) + pub fn status(&self, status: IMStatusCode) -> Option { + if self.should_report(status) { + Some(CmdStatus::new( + CmdPath::new( + Some(self.endpoint_id), + Some(self.cluster_id), + Some(self.cmd_id), + ), + status, + 0, + )) } else { - Err(Error::NoSpace) + None } } - fn get_attribute_index(&self, attr_id: AttrId) -> Option { - self.attributes.iter().position(|c| c.id == attr_id) - } - - fn get_attribute(&self, attr_id: AttrId) -> Result<&Attribute, Error> { - let index = self - .get_attribute_index(attr_id) - .ok_or(Error::AttributeNotFound)?; - Ok(&self.attributes[index]) + fn should_report(&self, status: IMStatusCode) -> bool { + !self.wildcard + || !matches!( + status, + IMStatusCode::UnsupportedEndpoint + | IMStatusCode::UnsupportedCluster + | IMStatusCode::UnsupportedAttribute + | IMStatusCode::UnsupportedCommand + | IMStatusCode::UnsupportedAccess + | IMStatusCode::UnsupportedRead + | IMStatusCode::UnsupportedWrite + ) } +} - fn get_attribute_mut(&mut self, attr_id: AttrId) -> Result<&mut Attribute, Error> { - let index = self - .get_attribute_index(attr_id) - .ok_or(Error::AttributeNotFound)?; - Ok(&mut self.attributes[index]) - } +#[derive(Debug, Clone)] +pub struct Cluster<'a> { + pub id: ClusterId, + pub feature_map: u32, + pub attributes: &'a [Attribute], + pub commands: &'a [CmdId], +} - // Returns a slice of attribute, with either a single attribute or all (wildcard) - pub fn get_wildcard_attribute( - &self, - attribute: Option, - ) -> Result<(&[Attribute], bool), IMStatusCode> { - if let Some(a) = attribute { - if let Some(i) = self.get_attribute_index(a) { - Ok((&self.attributes[i..i + 1], false)) - } else { - Err(IMStatusCode::UnsupportedAttribute) - } - } else { - Ok((&self.attributes[..], true)) +impl<'a> Cluster<'a> { + pub const fn new( + id: ClusterId, + feature_map: u32, + attributes: &'a [Attribute], + commands: &'a [CmdId], + ) -> Self { + Self { + id, + feature_map, + attributes, + commands, } } - pub fn read_attribute( - c: &dyn ClusterType, - access_req: &mut AccessReq, - encoder: &mut dyn Encoder, - attr: &AttrDetails, - ) { - let mut error = IMStatusCode::Success; - let base = c.base(); - let a = if let Ok(a) = base.get_attribute(attr.attr_id) { - a - } else { - encoder.encode_status(IMStatusCode::UnsupportedAttribute, 0); - return; - }; - - if !a.access.contains(Access::READ) { - error = IMStatusCode::UnsupportedRead; - } - - access_req.set_target_perms(a.access); - if !access_req.allow() { - error = IMStatusCode::UnsupportedAccess; - } - - if error != IMStatusCode::Success { - encoder.encode_status(error, 0); - } else if Attribute::is_system_attr(attr.attr_id) { - c.base().read_system_attribute(encoder, a) - } else if a.value != AttrValue::Custom { - encoder.encode(EncodeValue::Value(&a.value)) - } else { - c.read_custom_attribute(encoder, attr) - } + pub(crate) fn match_attributes<'m>( + &'m self, + accessor: &'m Accessor<'m>, + ep: EndptId, + attr: Option, + write: bool, + ) -> impl Iterator + 'm { + self.attributes + .iter() + .filter(move |attribute| attr.map(|attr| attr == attribute.id).unwrap_or(true)) + .filter(move |attribute| { + let mut access_req = AccessReq::new( + accessor, + GenericPath::new(Some(ep), Some(self.id), Some(attribute.id as _)), + if write { Access::WRITE } else { Access::READ }, + ); + self.check_attr_access(&mut access_req, attribute.access) + .is_ok() + }) + .map(|attribute| attribute.id) } - fn encode_attribute_ids(&self, tag: TagType, tw: &mut TLVWriter) { - let _ = tw.start_array(tag); - for a in &self.attributes { - let _ = tw.u16(TagType::Anonymous, a.id); - } - let _ = tw.end_container(); + pub fn match_commands<'m>( + &'m self, + accessor: &'m Accessor<'m>, + ep: EndptId, + cmd: Option, + ) -> impl Iterator + 'm { + self.commands + .iter() + .filter(move |id| cmd.map(|cmd| **id == cmd).unwrap_or(true)) + .filter(move |id| { + let mut access_req = AccessReq::new( + accessor, + GenericPath::new(Some(ep), Some(self.id), Some(**id as _)), + Access::WRITE, + ); + self.check_cmd_access(&mut access_req).is_ok() + }) + .copied() } - fn read_system_attribute(&self, encoder: &mut dyn Encoder, attr: &Attribute) { - let global_attr: Option = num::FromPrimitive::from_u16(attr.id); - if let Some(global_attr) = global_attr { - match global_attr { - GlobalElements::AttributeList => { - encoder.encode(EncodeValue::Closure(&|tag, tw| { - self.encode_attribute_ids(tag, tw) - })); - return; - } - GlobalElements::FeatureMap => { - encoder.encode(EncodeValue::Value(&attr.value)); - return; - } - _ => { - error!("This attribute not yet handled {:?}", global_attr); - } - } - } - encoder.encode_status(IMStatusCode::UnsupportedAttribute, 0) + pub(crate) fn check_attribute( + &self, + accessor: &Accessor, + ep: EndptId, + attr: AttrId, + write: bool, + ) -> Result<(), IMStatusCode> { + let attribute = self + .attributes + .iter() + .find(|attribute| attribute.id == attr) + .ok_or(IMStatusCode::UnsupportedAttribute)?; + + let mut access_req = AccessReq::new( + accessor, + GenericPath::new(Some(ep), Some(self.id), Some(attr as _)), + if write { Access::WRITE } else { Access::READ }, + ); + + self.check_attr_access(&mut access_req, attribute.access) } - pub fn read_attribute_raw(&self, attr_id: AttrId) -> Result<&AttrValue, IMStatusCode> { - let a = self - .get_attribute(attr_id) - .map_err(|_| IMStatusCode::UnsupportedAttribute)?; - Ok(&a.value) + pub(crate) fn check_command( + &self, + accessor: &Accessor, + ep: EndptId, + cmd: CmdId, + ) -> Result<(), IMStatusCode> { + self.commands + .iter() + .find(|id| **id == cmd) + .ok_or(IMStatusCode::UnsupportedCommand)?; + + let mut access_req = AccessReq::new( + accessor, + GenericPath::new(Some(ep), Some(self.id), Some(cmd as _)), + Access::WRITE, + ); + + self.check_cmd_access(&mut access_req) } - pub fn write_attribute( - c: &mut dyn ClusterType, + fn check_attr_access( + &self, access_req: &mut AccessReq, - data: &TLVElement, - attr: &AttrDetails, + target_perms: Access, ) -> Result<(), IMStatusCode> { - let base = c.base_mut(); - let a = if let Ok(a) = base.get_attribute_mut(attr.attr_id) { - a - } else { - return Err(IMStatusCode::UnsupportedAttribute); - }; - - if !a.access.contains(Access::WRITE) { - return Err(IMStatusCode::UnsupportedWrite); + if !target_perms.contains(access_req.operation()) { + Err(if matches!(access_req.operation(), Access::WRITE) { + IMStatusCode::UnsupportedWrite + } else { + IMStatusCode::UnsupportedRead + })?; } - access_req.set_target_perms(a.access); - if !access_req.allow() { - return Err(IMStatusCode::UnsupportedAccess); + access_req.set_target_perms(target_perms); + if access_req.allow() { + Ok(()) + } else { + Err(IMStatusCode::UnsupportedAccess) } - - c.write_attribute(attr, data) } - pub fn write_attribute_from_tlv( - &mut self, - attr_id: AttrId, - data: &TLVElement, - ) -> Result<(), IMStatusCode> { - let a = self.get_attribute_mut(attr_id)?; - if a.value != AttrValue::Custom { - let mut value = a.value.clone(); - value - .update_from_tlv(data) - .map_err(|_| IMStatusCode::Failure)?; - a.set_value(value) - .map(|_| { - self.cluster_changed(); - }) - .map_err(|_| IMStatusCode::UnsupportedWrite) + fn check_cmd_access(&self, access_req: &mut AccessReq) -> Result<(), IMStatusCode> { + access_req.set_target_perms( + Access::WRITE + .union(Access::NEED_OPERATE) + .union(Access::NEED_MANAGE) + .union(Access::NEED_ADMIN), + ); // TODO + if access_req.allow() { + Ok(()) } else { - Err(IMStatusCode::UnsupportedAttribute) + Err(IMStatusCode::UnsupportedAccess) } } - pub fn write_attribute_raw(&mut self, attr_id: AttrId, value: AttrValue) -> Result<(), Error> { - let a = self.get_attribute_mut(attr_id)?; - a.set_value(value).map(|_| { - self.cluster_changed(); - }) + pub fn read(&self, attr: AttrId, mut writer: AttrDataWriter) -> Result<(), Error> { + match attr.try_into()? { + GlobalElements::AttributeList => { + self.encode_attribute_ids(AttrDataWriter::TAG, &mut writer)?; + writer.complete() + } + GlobalElements::FeatureMap => writer.set(self.feature_map), + other => { + error!("This attribute is not yet handled {:?}", other); + Err(Error::AttributeNotFound) + } + } } - /// This method must be called for any changes to the data model - /// Currently this only increments the data version, but we can reuse the same - /// for raising events too - pub fn cluster_changed(&mut self) { - self.data_ver = self.data_ver.wrapping_add(1); + fn encode_attribute_ids(&self, tag: TagType, tw: &mut TLVWriter) -> Result<(), Error> { + tw.start_array(tag)?; + for a in self.attributes { + tw.u16(TagType::Anonymous, a.id)?; + } + + tw.end_container() } } -impl std::fmt::Display for Cluster { +impl<'a> core::fmt::Display for Cluster<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "id:{}, ", self.id)?; write!(f, "attrs[")?; diff --git a/matter/src/data_model/objects/dataver.rs b/matter/src/data_model/objects/dataver.rs new file mode 100644 index 00000000..fc062be0 --- /dev/null +++ b/matter/src/data_model/objects/dataver.rs @@ -0,0 +1,55 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + use crate::utils::rand::Rand; + +pub struct Dataver { + ver: u32, + changed: bool, +} + +impl Dataver { + pub fn new(rand: Rand) -> Self { + let mut buf = [0; 4]; + rand(&mut buf); + + Self { + ver: u32::from_be_bytes(buf), + changed: false, + } + } + + pub fn get(&self) -> u32 { + self.ver + } + + pub fn changed(&mut self) -> u32 { + (self.ver, _) = self.ver.overflowing_add(1); + self.changed = true; + + self.get() + } + + pub fn consume_change(&mut self, change: T) -> Option { + if self.changed { + self.changed = false; + Some(change) + } else { + None + } + } +} diff --git a/matter/src/data_model/objects/encoder.rs b/matter/src/data_model/objects/encoder.rs index d5653169..39d2ba6d 100644 --- a/matter/src/data_model/objects/encoder.rs +++ b/matter/src/data_model/objects/encoder.rs @@ -15,17 +15,26 @@ * limitations under the License. */ -use std::fmt::{Debug, Formatter}; +use core::fmt::{Debug, Formatter}; +use core::marker::PhantomData; +use core::ops::{Deref, DerefMut}; +use crate::interaction_model::core::{IMStatusCode, Transaction}; +use crate::interaction_model::messages::ib::{ + AttrPath, AttrResp, AttrStatus, CmdDataTag, CmdPath, CmdStatus, InvResp, InvRespTag, +}; +use crate::tlv::UtfStr; use crate::{ error::Error, - interaction_model::core::IMStatusCode, + interaction_model::messages::ib::{AttrDataTag, AttrRespTag}, tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, }; use log::error; +use super::{AttrDetails, CmdDetails, Handler}; + // TODO: Should this return an IMStatusCode Error? But if yes, the higher layer -// may have already started encoding the 'success' headers, we might not to manage +// may have already started encoding the 'success' headers, we might not want to manage // the tw.rewind() in that case, if we add this support pub type EncodeValueGen<'a> = &'a dyn Fn(TagType, &mut TLVWriter); @@ -78,7 +87,7 @@ impl<'a> PartialEq for EncodeValue<'a> { } impl<'a> Debug for EncodeValue<'a> { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), core::fmt::Error> { match *self { EncodeValue::Closure(_) => write!(f, "Contains closure"), EncodeValue::Tlv(t) => write!(f, "{:?}", t), @@ -107,17 +116,454 @@ impl<'a> FromTLV<'a> for EncodeValue<'a> { } } -/// An object that can encode EncodeValue into the necessary hierarchical structure -/// as expected by the Interaction Model -pub trait Encoder { - /// Encode a given value - fn encode(&mut self, value: EncodeValue); - /// Encode a status report - fn encode_status(&mut self, status: IMStatusCode, cluster_status: u16); +pub struct AttrDataEncoder<'a, 'b, 'c> { + dataver_filter: Option, + path: AttrPath, + tw: &'a mut TLVWriter<'b, 'c>, +} + +impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { + pub fn handle_read( + item: Result, + handler: &T, + tw: &mut TLVWriter, + ) -> Result<(), Error> { + let status = match item { + Ok(attr) => { + let encoder = AttrDataEncoder::new(&attr, tw); + + match handler.read(&attr, encoder) { + Ok(()) => None, + Err(error) => attr.status(error.into())?, + } + } + Err(status) => Some(status), + }; + + if let Some(status) = status { + AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; + } + + Ok(()) + } + + pub fn handle_write( + item: Result<(AttrDetails, TLVElement), AttrStatus>, + handler: &mut T, + tw: &mut TLVWriter, + ) -> Result<(), Error> { + let status = match item { + Ok((attr, data)) => match handler.write(&attr, AttrData::new(attr.dataver, &data)) { + Ok(()) => attr.status(IMStatusCode::Success)?, + Err(error) => attr.status(error.into())?, + }, + Err(status) => Some(status), + }; + + if let Some(status) = status { + status.to_tlv(tw, TagType::Anonymous)?; + } + + Ok(()) + } + + #[cfg(feature = "nightly")] + pub async fn handle_read_async( + item: Result, AttrStatus>, + handler: &T, + tw: &mut TLVWriter<'_, '_>, + ) -> Result<(), Error> { + let status = match item { + Ok(attr) => { + let encoder = AttrDataEncoder::new(&attr, tw); + + match handler.read(&attr, encoder).await { + Ok(()) => None, + Err(error) => attr.status(error.into())?, + } + } + Err(status) => Some(status), + }; + + if let Some(status) = status { + AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; + } + + Ok(()) + } + + #[cfg(feature = "nightly")] + pub async fn handle_write_async( + item: Result<(AttrDetails<'_>, TLVElement<'_>), AttrStatus>, + handler: &mut T, + tw: &mut TLVWriter<'_, '_>, + ) -> Result<(), Error> { + let status = match item { + Ok((attr, data)) => match handler + .write(&attr, AttrData::new(attr.dataver, &data)) + .await + { + Ok(()) => attr.status(IMStatusCode::Success)?, + Err(error) => attr.status(error.into())?, + }, + Err(status) => Some(status), + }; + + if let Some(status) = status { + AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; + } + + Ok(()) + } + + pub fn new(attr: &AttrDetails, tw: &'a mut TLVWriter<'b, 'c>) -> Self { + Self { + dataver_filter: attr.dataver, + path: attr.path(), + tw, + } + } + + pub fn with_dataver(self, dataver: u32) -> Result>, Error> { + if self + .dataver_filter + .map(|dataver_filter| dataver_filter != dataver) + .unwrap_or(true) + { + let mut writer = AttrDataWriter::new(self.tw); + + writer.start_struct(TagType::Anonymous)?; + writer.start_struct(TagType::Context(AttrRespTag::Data as _))?; + writer.u32(TagType::Context(AttrDataTag::DataVer as _), dataver)?; + self.path + .to_tlv(&mut writer, TagType::Context(AttrDataTag::Path as _))?; + + Ok(Some(writer)) + } else { + Ok(None) + } + } +} + +pub struct AttrDataWriter<'a, 'b, 'c> { + tw: &'a mut TLVWriter<'b, 'c>, + anchor: usize, + completed: bool, +} + +impl<'a, 'b, 'c> AttrDataWriter<'a, 'b, 'c> { + pub const TAG: TagType = TagType::Context(AttrDataTag::Data as _); + + fn new(tw: &'a mut TLVWriter<'b, 'c>) -> Self { + let anchor = tw.get_tail(); + + Self { + tw, + anchor, + completed: false, + } + } + + pub fn set(self, value: T) -> Result<(), Error> { + value.to_tlv(self.tw, Self::TAG)?; + self.complete() + } + + pub fn complete(mut self) -> Result<(), Error> { + self.tw.end_container()?; + self.tw.end_container()?; + + self.completed = true; + + Ok(()) + } + + fn reset(&mut self) { + self.tw.rewind_to(self.anchor); + } +} + +impl<'a, 'b, 'c> Drop for AttrDataWriter<'a, 'b, 'c> { + fn drop(&mut self) { + if !self.completed { + self.reset(); + } + } +} + +impl<'a, 'b, 'c> Deref for AttrDataWriter<'a, 'b, 'c> { + type Target = TLVWriter<'b, 'c>; + + fn deref(&self) -> &Self::Target { + self.tw + } +} + +impl<'a, 'b, 'c> DerefMut for AttrDataWriter<'a, 'b, 'c> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.tw + } +} + +pub struct AttrData<'a> { + for_dataver: Option, + data: &'a TLVElement<'a>, +} + +impl<'a> AttrData<'a> { + pub fn new(for_dataver: Option, data: &'a TLVElement<'a>) -> Self { + Self { for_dataver, data } + } + + pub fn with_dataver(self, dataver: u32) -> Result<&'a TLVElement<'a>, Error> { + if let Some(req_dataver) = self.for_dataver { + if req_dataver != dataver { + return Err(Error::DataVersionMismatch); + } + } + + Ok(self.data) + } +} + +#[derive(Default)] +pub struct CmdDataTracker { + skip_status: bool, +} + +impl CmdDataTracker { + pub const fn new() -> Self { + Self { skip_status: false } + } + + pub(crate) fn complete(&mut self) { + self.skip_status = true; + } + + pub fn needs_status(&self) -> bool { + !self.skip_status + } +} + +pub struct CmdDataEncoder<'a, 'b, 'c> { + tracker: &'a mut CmdDataTracker, + path: CmdPath, + tw: &'a mut TLVWriter<'b, 'c>, +} + +impl<'a, 'b, 'c> CmdDataEncoder<'a, 'b, 'c> { + pub fn handle( + item: Result<(CmdDetails, TLVElement), CmdStatus>, + handler: &mut T, + transaction: &mut Transaction, + tw: &mut TLVWriter, + ) -> Result<(), Error> { + let status = match item { + Ok((cmd, data)) => { + let mut tracker = CmdDataTracker::new(); + let encoder = CmdDataEncoder::new(&cmd, &mut tracker, tw); + + match handler.invoke(transaction, &cmd, &data, encoder) { + Ok(()) => cmd.success(&tracker), + Err(error) => cmd.status(error.into()), + } + } + Err(status) => Some(status), + }; + + if let Some(status) = status { + InvResp::Status(status).to_tlv(tw, TagType::Anonymous)?; + } + + Ok(()) + } + + #[cfg(feature = "nightly")] + pub async fn handle_async( + item: Result<(CmdDetails<'_>, TLVElement<'_>), CmdStatus>, + handler: &mut T, + transaction: &mut Transaction<'_, '_>, + tw: &mut TLVWriter<'_, '_>, + ) -> Result<(), Error> { + let status = match item { + Ok((cmd, data)) => { + let mut tracker = CmdDataTracker::new(); + let encoder = CmdDataEncoder::new(&cmd, &mut tracker, tw); + + match handler.invoke(transaction, &cmd, &data, encoder).await { + Ok(()) => cmd.success(&tracker), + Err(error) => cmd.status(error.into()), + } + } + Err(status) => Some(status), + }; + + if let Some(status) = status { + InvResp::Status(status).to_tlv(tw, TagType::Anonymous)?; + } + + Ok(()) + } + + pub fn new( + cmd: &CmdDetails, + tracker: &'a mut CmdDataTracker, + tw: &'a mut TLVWriter<'b, 'c>, + ) -> Self { + Self { + tracker, + path: cmd.path(), + tw, + } + } + + pub fn with_command(mut self, cmd: u16) -> Result, Error> { + let mut writer = CmdDataWriter::new(self.tracker, self.tw); + + writer.start_struct(TagType::Anonymous)?; + writer.start_struct(TagType::Context(InvRespTag::Cmd as _))?; + + self.path.path.leaf = Some(cmd as _); + self.path + .to_tlv(&mut writer, TagType::Context(CmdDataTag::Path as _))?; + + Ok(writer) + } +} + +pub struct CmdDataWriter<'a, 'b, 'c> { + tracker: &'a mut CmdDataTracker, + tw: &'a mut TLVWriter<'b, 'c>, + anchor: usize, + completed: bool, +} + +impl<'a, 'b, 'c> CmdDataWriter<'a, 'b, 'c> { + pub const TAG: TagType = TagType::Context(CmdDataTag::Data as _); + + fn new(tracker: &'a mut CmdDataTracker, tw: &'a mut TLVWriter<'b, 'c>) -> Self { + let anchor = tw.get_tail(); + + Self { + tracker, + tw, + anchor, + completed: false, + } + } + + pub fn set(self, value: T) -> Result<(), Error> { + value.to_tlv(self.tw, Self::TAG)?; + self.complete() + } + + pub fn complete(mut self) -> Result<(), Error> { + self.tw.end_container()?; + self.tw.end_container()?; + + self.completed = true; + self.tracker.complete(); + + Ok(()) + } + + fn reset(&mut self) { + self.tw.rewind_to(self.anchor); + } +} + +impl<'a, 'b, 'c> Drop for CmdDataWriter<'a, 'b, 'c> { + fn drop(&mut self) { + if !self.completed { + self.reset(); + } + } +} + +impl<'a, 'b, 'c> Deref for CmdDataWriter<'a, 'b, 'c> { + type Target = TLVWriter<'b, 'c>; + + fn deref(&self) -> &Self::Target { + self.tw + } +} + +impl<'a, 'b, 'c> DerefMut for CmdDataWriter<'a, 'b, 'c> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.tw + } } -#[derive(ToTLV, Copy, Clone)] -pub struct DeviceType { - pub dtype: u16, - pub drev: u16, +#[derive(Copy, Clone, Debug)] +pub struct AttrType(PhantomData T>); + +impl AttrType { + pub const fn new() -> Self { + Self(PhantomData) + } + + pub fn encode(&self, writer: AttrDataWriter, value: T) -> Result<(), Error> + where + T: ToTLV, + { + writer.set(value) + } + + pub fn decode<'a>(&self, data: &'a TLVElement) -> Result + where + T: FromTLV<'a>, + { + T::from_tlv(data) + } +} + +impl Default for AttrType { + fn default() -> Self { + Self::new() + } +} + +#[derive(Copy, Clone, Debug, Default)] +pub struct AttrUtfType; + +impl AttrUtfType { + pub const fn new() -> Self { + Self + } + + pub fn encode(&self, writer: AttrDataWriter, value: &str) -> Result<(), Error> { + writer.set(UtfStr::new(value.as_bytes())) + } + + pub fn decode<'a>(&self, data: &'a TLVElement) -> Result<&'a str, IMStatusCode> { + data.str().map_err(|_| IMStatusCode::InvalidDataType) + } +} + +#[allow(unused_macros)] +#[macro_export] +macro_rules! attribute_enum { + ($en:ty) => { + impl core::convert::TryFrom<$crate::data_model::objects::AttrId> for $en { + type Error = $crate::error::Error; + + fn try_from(id: $crate::data_model::objects::AttrId) -> Result { + <$en>::from_repr(id).ok_or($crate::error::Error::AttributeNotFound) + } + } + }; +} + +#[allow(unused_macros)] +#[macro_export] +macro_rules! command_enum { + ($en:ty) => { + impl core::convert::TryFrom<$crate::data_model::objects::CmdId> for $en { + type Error = $crate::error::Error; + + fn try_from(id: $crate::data_model::objects::CmdId) -> Result { + <$en>::from_repr(id).ok_or($crate::error::Error::CommandNotFound) + } + } + }; } diff --git a/matter/src/data_model/objects/endpoint.rs b/matter/src/data_model/objects/endpoint.rs index 466e7a64..d0a4fddf 100644 --- a/matter/src/data_model/objects/endpoint.rs +++ b/matter/src/data_model/objects/endpoint.rs @@ -15,104 +15,91 @@ * limitations under the License. */ -use crate::{data_model::objects::ClusterType, error::*, interaction_model::core::IMStatusCode}; +use crate::{acl::Accessor, interaction_model::core::IMStatusCode}; -use std::fmt; +use core::fmt; -use super::{ClusterId, DeviceType}; +use super::{AttrId, Cluster, ClusterId, CmdId, DeviceType, EndptId}; -pub const CLUSTERS_PER_ENDPT: usize = 9; - -pub struct Endpoint { - dev_type: DeviceType, - clusters: Vec>, +#[derive(Debug, Clone)] +pub struct Endpoint<'a> { + pub id: EndptId, + pub device_type: DeviceType, + pub clusters: &'a [Cluster<'a>], } -pub type BoxedClusters = [Box]; - -impl Endpoint { - pub fn new(dev_type: DeviceType) -> Result, Error> { - Ok(Box::new(Endpoint { - dev_type, - clusters: Vec::with_capacity(CLUSTERS_PER_ENDPT), - })) - } - - pub fn add_cluster(&mut self, cluster: Box) -> Result<(), Error> { - if self.clusters.len() < self.clusters.capacity() { - self.clusters.push(cluster); - Ok(()) - } else { - Err(Error::NoSpace) - } - } - - pub fn get_dev_type(&self) -> &DeviceType { - &self.dev_type +impl<'a> Endpoint<'a> { + pub(crate) fn match_attributes<'m>( + &'m self, + accessor: &'m Accessor<'m>, + cl: Option, + attr: Option, + write: bool, + ) -> impl Iterator + 'm { + self.match_clusters(cl).flat_map(move |cluster| { + cluster + .match_attributes(accessor, self.id, attr, write) + .map(move |attr| (cluster.id, attr)) + }) } - fn get_cluster_index(&self, cluster_id: ClusterId) -> Option { - self.clusters.iter().position(|c| c.base().id == cluster_id) + pub(crate) fn match_commands<'m>( + &'m self, + accessor: &'m Accessor<'m>, + cl: Option, + cmd: Option, + ) -> impl Iterator + 'm { + self.match_clusters(cl).flat_map(move |cluster| { + cluster + .match_commands(accessor, self.id, cmd) + .map(move |cmd| (cluster.id, cmd)) + }) } - pub fn get_cluster(&self, cluster_id: ClusterId) -> Result<&dyn ClusterType, Error> { - let index = self - .get_cluster_index(cluster_id) - .ok_or(Error::ClusterNotFound)?; - Ok(self.clusters[index].as_ref()) + pub(crate) fn check_attribute( + &self, + accessor: &Accessor, + cl: ClusterId, + attr: AttrId, + write: bool, + ) -> Result<(), IMStatusCode> { + self.check_cluster(cl) + .and_then(|cluster| cluster.check_attribute(accessor, self.id, attr, write)) } - pub fn get_cluster_mut( - &mut self, - cluster_id: ClusterId, - ) -> Result<&mut dyn ClusterType, Error> { - let index = self - .get_cluster_index(cluster_id) - .ok_or(Error::ClusterNotFound)?; - Ok(self.clusters[index].as_mut()) + pub(crate) fn check_command( + &self, + accessor: &Accessor, + cl: ClusterId, + cmd: CmdId, + ) -> Result<(), IMStatusCode> { + self.check_cluster(cl) + .and_then(|cluster| cluster.check_command(accessor, self.id, cmd)) } - // Returns a slice of clusters, with either a single cluster or all (wildcard) - pub fn get_wildcard_clusters( - &self, - cluster: Option, - ) -> Result<(&BoxedClusters, bool), IMStatusCode> { - if let Some(c) = cluster { - if let Some(i) = self.get_cluster_index(c) { - Ok((&self.clusters[i..i + 1], false)) - } else { - Err(IMStatusCode::UnsupportedCluster) - } - } else { - Ok((self.clusters.as_slice(), true)) - } + fn match_clusters(&self, cl: Option) -> impl Iterator + '_ { + self.clusters + .iter() + .filter(move |cluster| cl.map(|id| id == cluster.id).unwrap_or(true)) } - // Returns a slice of clusters, with either a single cluster or all (wildcard) - pub fn get_wildcard_clusters_mut( - &mut self, - cluster: Option, - ) -> Result<(&mut BoxedClusters, bool), IMStatusCode> { - if let Some(c) = cluster { - if let Some(i) = self.get_cluster_index(c) { - Ok((&mut self.clusters[i..i + 1], false)) - } else { - Err(IMStatusCode::UnsupportedCluster) - } - } else { - Ok((&mut self.clusters[..], true)) - } + fn check_cluster(&self, cl: ClusterId) -> Result<&Cluster, IMStatusCode> { + self.clusters + .iter() + .find(|cluster| cluster.id == cl) + .ok_or(IMStatusCode::UnsupportedCluster) } } -impl std::fmt::Display for Endpoint { +impl<'a> core::fmt::Display for Endpoint<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "clusters:[")?; let mut comma = ""; - for element in self.clusters.iter() { - write!(f, "{} {{ {} }}", comma, element.base())?; + for cluster in self.clusters { + write!(f, "{} {{ {} }}", comma, cluster)?; comma = ", "; } + write!(f, "]") } } diff --git a/matter/src/data_model/objects/handler.rs b/matter/src/data_model/objects/handler.rs new file mode 100644 index 00000000..052d6906 --- /dev/null +++ b/matter/src/data_model/objects/handler.rs @@ -0,0 +1,350 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use crate::{error::Error, interaction_model::core::Transaction, tlv::TLVElement}; + +use super::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails}; + +pub trait ChangeNotifier { + fn consume_change(&mut self) -> Option; +} + +pub trait Handler { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error>; + + fn write(&mut self, _attr: &AttrDetails, _data: AttrData) -> Result<(), Error> { + Err(Error::AttributeNotFound) + } + + fn invoke( + &mut self, + _transaction: &mut Transaction, + _cmd: &CmdDetails, + _data: &TLVElement, + _encoder: CmdDataEncoder, + ) -> Result<(), Error> { + Err(Error::CommandNotFound) + } +} + +impl Handler for &mut T +where + T: Handler, +{ + fn read<'a>(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + (**self).read(attr, encoder) + } + + fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + (**self).write(attr, data) + } + + fn invoke( + &mut self, + transaction: &mut Transaction, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + (**self).invoke(transaction, cmd, data, encoder) + } +} + +pub trait NonBlockingHandler: Handler {} + +impl NonBlockingHandler for &mut T where T: NonBlockingHandler {} + +pub struct EmptyHandler; + +impl EmptyHandler { + pub const fn chain( + self, + handler_endpoint: u16, + handler_cluster: u32, + handler: H, + ) -> ChainedHandler { + ChainedHandler { + handler_endpoint, + handler_cluster, + handler, + next: self, + } + } +} + +impl Handler for EmptyHandler { + fn read(&self, _attr: &AttrDetails, _encoder: AttrDataEncoder) -> Result<(), Error> { + Err(Error::AttributeNotFound) + } +} + +impl NonBlockingHandler for EmptyHandler {} + +impl ChangeNotifier<(u16, u32)> for EmptyHandler { + fn consume_change(&mut self) -> Option<(u16, u32)> { + None + } +} + +pub struct ChainedHandler { + pub handler_endpoint: u16, + pub handler_cluster: u32, + pub handler: H, + pub next: T, +} + +impl ChainedHandler { + pub const fn chain

( + self, + handler_endpoint: u16, + handler_cluster: u32, + handler: H2, + ) -> ChainedHandler { + ChainedHandler { + handler_endpoint, + handler_cluster, + handler, + next: self, + } + } +} + +impl Handler for ChainedHandler +where + H: Handler, + T: Handler, +{ + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id { + self.handler.read(attr, encoder) + } else { + self.next.read(attr, encoder) + } + } + + fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id { + self.handler.write(attr, data) + } else { + self.next.write(attr, data) + } + } + + fn invoke( + &mut self, + transaction: &mut Transaction, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + if self.handler_endpoint == cmd.endpoint_id && self.handler_cluster == cmd.cluster_id { + self.handler.invoke(transaction, cmd, data, encoder) + } else { + self.next.invoke(transaction, cmd, data, encoder) + } + } +} + +impl NonBlockingHandler for ChainedHandler +where + H: NonBlockingHandler, + T: NonBlockingHandler, +{ +} + +impl ChangeNotifier<(u16, u32)> for ChainedHandler +where + H: ChangeNotifier<()>, + T: ChangeNotifier<(u16, u32)>, +{ + fn consume_change(&mut self) -> Option<(u16, u32)> { + if self.handler.consume_change().is_some() { + Some((self.handler_endpoint, self.handler_cluster)) + } else { + self.next.consume_change() + } + } +} + +#[allow(unused_macros)] +#[macro_export] +macro_rules! handler_chain_type { + ($h:ty) => { + $crate::data_model::objects::ChainedHandler<$h, $crate::data_model::objects::EmptyHandler> + }; + ($h1:ty, $($rest:ty),+) => { + $crate::data_model::objects::ChainedHandler<$h1, handler_chain_type!($($rest),+)> + }; +} + +#[cfg(feature = "nightly")] +pub mod asynch { + use crate::{ + data_model::objects::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails}, + error::Error, + interaction_model::core::Transaction, + tlv::TLVElement, + }; + + use super::{ChainedHandler, EmptyHandler, Handler, NonBlockingHandler}; + + pub trait AsyncHandler { + async fn read<'a>( + &'a self, + attr: &'a AttrDetails<'_>, + encoder: AttrDataEncoder<'a, '_, '_>, + ) -> Result<(), Error>; + + async fn write<'a>( + &'a mut self, + _attr: &'a AttrDetails<'_>, + _data: AttrData<'a>, + ) -> Result<(), Error> { + Err(Error::AttributeNotFound) + } + + async fn invoke<'a>( + &'a mut self, + _transaction: &'a mut Transaction<'_, '_>, + _cmd: &'a CmdDetails<'_>, + _data: &'a TLVElement<'_>, + _encoder: CmdDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + Err(Error::CommandNotFound) + } + } + + impl AsyncHandler for &mut T + where + T: AsyncHandler, + { + async fn read<'a>( + &'a self, + attr: &'a AttrDetails<'_>, + encoder: AttrDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + (**self).read(attr, encoder).await + } + + async fn write<'a>( + &'a mut self, + attr: &'a AttrDetails<'_>, + data: AttrData<'a>, + ) -> Result<(), Error> { + (**self).write(attr, data).await + } + + async fn invoke<'a>( + &'a mut self, + transaction: &'a mut Transaction<'_, '_>, + cmd: &'a CmdDetails<'_>, + data: &'a TLVElement<'_>, + encoder: CmdDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + (**self).invoke(transaction, cmd, data, encoder).await + } + } + + pub struct Asyncify(pub T); + + impl AsyncHandler for Asyncify + where + T: NonBlockingHandler, + { + async fn read<'a>( + &'a self, + attr: &'a AttrDetails<'_>, + encoder: AttrDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + Handler::read(&self.0, attr, encoder) + } + + async fn write<'a>( + &'a mut self, + attr: &'a AttrDetails<'_>, + data: AttrData<'a>, + ) -> Result<(), Error> { + Handler::write(&mut self.0, attr, data) + } + + async fn invoke<'a>( + &'a mut self, + transaction: &'a mut Transaction<'_, '_>, + cmd: &'a CmdDetails<'_>, + data: &'a TLVElement<'_>, + encoder: CmdDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + Handler::invoke(&mut self.0, transaction, cmd, data, encoder) + } + } + + impl AsyncHandler for EmptyHandler { + async fn read<'a>( + &'a self, + _attr: &'a AttrDetails<'_>, + _encoder: AttrDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + Err(Error::AttributeNotFound) + } + } + + impl AsyncHandler for ChainedHandler + where + H: AsyncHandler, + T: AsyncHandler, + { + async fn read<'a>( + &'a self, + attr: &'a AttrDetails<'_>, + encoder: AttrDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id + { + self.handler.read(attr, encoder).await + } else { + self.next.read(attr, encoder).await + } + } + + async fn write<'a>( + &'a mut self, + attr: &'a AttrDetails<'_>, + data: AttrData<'a>, + ) -> Result<(), Error> { + if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id + { + self.handler.write(attr, data).await + } else { + self.next.write(attr, data).await + } + } + + async fn invoke<'a>( + &'a mut self, + transaction: &'a mut Transaction<'_, '_>, + cmd: &'a CmdDetails<'_>, + data: &'a TLVElement<'_>, + encoder: CmdDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + if self.handler_endpoint == cmd.endpoint_id && self.handler_cluster == cmd.cluster_id { + self.handler.invoke(transaction, cmd, data, encoder).await + } else { + self.next.invoke(transaction, cmd, data, encoder).await + } + } + } +} diff --git a/matter/src/data_model/objects/mod.rs b/matter/src/data_model/objects/mod.rs index 2fb3aff5..1bd326e4 100644 --- a/matter/src/data_model/objects/mod.rs +++ b/matter/src/data_model/objects/mod.rs @@ -14,11 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -pub type EndptId = u16; -pub type ClusterId = u32; -pub type AttrId = u16; -pub type CmdId = u32; +use crate::error::Error; +use crate::tlv::{TLVWriter, TagType, ToTLV}; mod attribute; pub use attribute::*; @@ -37,3 +34,20 @@ pub use privilege::*; mod encoder; pub use encoder::*; + +mod handler; +pub use handler::*; + +mod dataver; +pub use dataver::*; + +pub type EndptId = u16; +pub type ClusterId = u32; +pub type AttrId = u16; +pub type CmdId = u32; + +#[derive(Debug, ToTLV, Copy, Clone)] +pub struct DeviceType { + pub dtype: u16, + pub drev: u16, +} diff --git a/matter/src/data_model/objects/node.rs b/matter/src/data_model/objects/node.rs index ba2f0b28..2eb11754 100644 --- a/matter/src/data_model/objects/node.rs +++ b/matter/src/data_model/objects/node.rs @@ -16,283 +16,379 @@ */ use crate::{ - data_model::objects::{ClusterType, Endpoint}, - error::*, - interaction_model::{core::IMStatusCode, messages::GenericPath}, + acl::Accessor, + data_model::objects::Endpoint, + interaction_model::{ + core::IMStatusCode, + messages::{ + ib::{AttrStatus, CmdStatus}, + msg::{InvReq, ReadReq, WriteReq}, + GenericPath, + }, + }, // TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer + tlv::TLVElement, +}; +use core::{ + fmt, + iter::{once, Once}, }; -use std::fmt; -use super::{ClusterId, DeviceType, EndptId}; +use super::{AttrDetails, AttrId, ClusterId, CmdDetails, CmdId, EndptId}; -pub trait ChangeConsumer { - fn endpoint_added(&self, id: EndptId, endpoint: &mut Endpoint) -> Result<(), Error>; +enum WildcardIter { + None, + Single(Once), + Wildcard(T), } -pub const ENDPTS_PER_ACC: usize = 3; +impl Iterator for WildcardIter +where + T: Iterator, +{ + type Item = E; -pub type BoxedEndpoints = [Option>]; + fn next(&mut self) -> Option { + match self { + Self::None => None, + Self::Single(iter) => iter.next(), + Self::Wildcard(iter) => iter.next(), + } + } +} -#[derive(Default)] -pub struct Node { - endpoints: [Option>; ENDPTS_PER_ACC], - changes_cb: Option>, +#[derive(Debug, Clone)] +pub struct Node<'a> { + pub id: u16, + pub endpoints: &'a [Endpoint<'a>], } -impl std::fmt::Display for Node { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "node:")?; - for (i, element) in self.endpoints.iter().enumerate() { - if let Some(e) = element { - writeln!(f, "endpoint {}: {}", i, e)?; - } +impl<'a> Node<'a> { + pub fn read<'s, 'm>( + &'s self, + req: &'m ReadReq, + accessor: &'m Accessor<'m>, + ) -> impl Iterator> + 'm + where + 's: 'm, + { + if let Some(attr_requests) = req.attr_requests.as_ref() { + WildcardIter::Wildcard(attr_requests.iter().flat_map( + move |path| match self.expand_attr(accessor, path.to_gp(), false) { + Ok(iter) => { + let wildcard = matches!(iter, WildcardIter::Wildcard(_)); + + WildcardIter::Wildcard(iter.map(move |(ep, cl, attr)| { + let dataver_filter = req + .dataver_filters + .as_ref() + .iter() + .flat_map(|array| array.iter()) + .find_map(|filter| { + (filter.path.endpoint == ep && filter.path.cluster == cl) + .then_some(filter.data_ver) + }); + + Ok(AttrDetails { + node: self, + endpoint_id: ep, + cluster_id: cl, + attr_id: attr, + list_index: path.list_index, + fab_idx: accessor.fab_idx, + fab_filter: req.fabric_filtered, + dataver: dataver_filter, + wildcard, + }) + })) + } + Err(err) => { + WildcardIter::Single(once(Err(AttrStatus::new(&path.to_gp(), err, 0)))) + } + }, + )) + } else { + WildcardIter::None } - write!(f, "") } -} -impl Node { - pub fn new() -> Result, Error> { - let node = Box::default(); - Ok(node) - } + pub fn write<'m>( + &'m self, + req: &'m WriteReq, + accessor: &'m Accessor<'m>, + ) -> impl Iterator), AttrStatus>> + 'm { + req.write_requests.iter().flat_map(move |attr_data| { + if attr_data.path.cluster.is_none() { + WildcardIter::Single(once(Err(AttrStatus::new( + &attr_data.path.to_gp(), + IMStatusCode::UnsupportedCluster, + 0, + )))) + } else if attr_data.path.attr.is_none() { + WildcardIter::Single(once(Err(AttrStatus::new( + &attr_data.path.to_gp(), + IMStatusCode::UnsupportedAttribute, + 0, + )))) + } else { + match self.expand_attr(accessor, attr_data.path.to_gp(), true) { + Ok(iter) => { + let wildcard = matches!(iter, WildcardIter::Wildcard(_)); - pub fn set_changes_cb(&mut self, consumer: Box) { - self.changes_cb = Some(consumer); + WildcardIter::Wildcard(iter.map(move |(ep, cl, attr)| { + Ok(( + AttrDetails { + node: self, + endpoint_id: ep, + cluster_id: cl, + attr_id: attr, + list_index: attr_data.path.list_index, + fab_idx: accessor.fab_idx, + fab_filter: false, + dataver: attr_data.data_ver, + wildcard, + }, + attr_data.data.unwrap_tlv().unwrap(), + )) + })) + } + Err(err) => WildcardIter::Single(once(Err(AttrStatus::new( + &attr_data.path.to_gp(), + err, + 0, + )))), + } + } + }) } - pub fn add_endpoint(&mut self, dev_type: DeviceType) -> Result { - let index = self - .endpoints - .iter() - .position(|x| x.is_none()) - .ok_or(Error::NoSpace)?; - let mut endpoint = Endpoint::new(dev_type)?; - if let Some(cb) = &self.changes_cb { - cb.endpoint_added(index as EndptId, &mut endpoint)?; + pub fn invoke<'m>( + &'m self, + req: &'m InvReq, + accessor: &'m Accessor<'m>, + ) -> impl Iterator), CmdStatus>> + 'm { + if let Some(inv_requests) = req.inv_requests.as_ref() { + WildcardIter::Wildcard(inv_requests.iter().flat_map(move |cmd_data| { + match self.expand_cmd(accessor, cmd_data.path.path) { + Ok(iter) => { + let wildcard = matches!(iter, WildcardIter::Wildcard(_)); + + WildcardIter::Wildcard(iter.map(move |(ep, cl, cmd)| { + Ok(( + CmdDetails { + node: self, + endpoint_id: ep, + cluster_id: cl, + cmd_id: cmd, + wildcard, + }, + cmd_data.data.unwrap_tlv().unwrap(), + )) + })) + } + Err(err) => { + WildcardIter::Single(once(Err(CmdStatus::new(cmd_data.path, err, 0)))) + } + } + })) + } else { + WildcardIter::None } - self.endpoints[index] = Some(endpoint); - Ok(index as EndptId) } - pub fn get_endpoint(&self, endpoint_id: EndptId) -> Result<&Endpoint, Error> { - if (endpoint_id as usize) < ENDPTS_PER_ACC { - let endpoint = self.endpoints[endpoint_id as usize] - .as_ref() - .ok_or(Error::EndpointNotFound)?; - Ok(endpoint) + fn expand_attr<'m>( + &'m self, + accessor: &'m Accessor<'m>, + path: GenericPath, + write: bool, + ) -> Result< + WildcardIter< + impl Iterator + 'm, + (EndptId, ClusterId, AttrId), + >, + IMStatusCode, + > { + if path.is_wildcard() { + Ok(WildcardIter::Wildcard(self.match_attributes( + accessor, + path.endpoint, + path.cluster, + path.leaf.map(|leaf| leaf as u16), + write, + ))) } else { - Err(Error::EndpointNotFound) + self.check_attribute( + accessor, + path.endpoint.unwrap(), + path.cluster.unwrap(), + path.leaf.unwrap() as _, + write, + )?; + + Ok(WildcardIter::Single(once(( + path.endpoint.unwrap(), + path.cluster.unwrap(), + path.leaf.unwrap() as _, + )))) } } - pub fn get_endpoint_mut(&mut self, endpoint_id: EndptId) -> Result<&mut Endpoint, Error> { - if (endpoint_id as usize) < ENDPTS_PER_ACC { - let endpoint = self.endpoints[endpoint_id as usize] - .as_mut() - .ok_or(Error::EndpointNotFound)?; - Ok(endpoint) + fn expand_cmd<'m>( + &'m self, + accessor: &'m Accessor<'m>, + path: GenericPath, + ) -> Result< + WildcardIter< + impl Iterator + 'm, + (EndptId, ClusterId, CmdId), + >, + IMStatusCode, + > { + if path.is_wildcard() { + Ok(WildcardIter::Wildcard(self.match_commands( + accessor, + path.endpoint, + path.cluster, + path.leaf, + ))) } else { - Err(Error::EndpointNotFound) + self.check_command( + accessor, + path.endpoint.unwrap(), + path.cluster.unwrap(), + path.leaf.unwrap(), + )?; + + Ok(WildcardIter::Single(once(( + path.endpoint.unwrap(), + path.cluster.unwrap(), + path.leaf.unwrap(), + )))) } } - pub fn get_cluster_mut( - &mut self, - e: EndptId, - c: ClusterId, - ) -> Result<&mut dyn ClusterType, Error> { - self.get_endpoint_mut(e)?.get_cluster_mut(c) + fn match_attributes<'m>( + &'m self, + accessor: &'m Accessor<'m>, + ep: Option, + cl: Option, + attr: Option, + write: bool, + ) -> impl Iterator + 'm { + self.match_endpoints(ep).flat_map(move |endpoint| { + endpoint + .match_attributes(accessor, cl, attr, write) + .map(move |(cl, attr)| (endpoint.id, cl, attr)) + }) } - pub fn get_cluster(&self, e: EndptId, c: ClusterId) -> Result<&dyn ClusterType, Error> { - self.get_endpoint(e)?.get_cluster(c) + fn match_commands<'m>( + &'m self, + accessor: &'m Accessor<'m>, + ep: Option, + cl: Option, + cmd: Option, + ) -> impl Iterator + 'm { + self.match_endpoints(ep).flat_map(move |endpoint| { + endpoint + .match_commands(accessor, cl, cmd) + .map(move |(cl, cmd)| (endpoint.id, cl, cmd)) + }) } - pub fn add_cluster( - &mut self, - endpoint_id: EndptId, - cluster: Box, - ) -> Result<(), Error> { - let endpoint_id = endpoint_id as usize; - if endpoint_id < ENDPTS_PER_ACC { - self.endpoints[endpoint_id] - .as_mut() - .ok_or(Error::NoEndpoint)? - .add_cluster(cluster) - } else { - Err(Error::Invalid) - } + fn check_attribute( + &self, + accessor: &Accessor, + ep: EndptId, + cl: ClusterId, + attr: AttrId, + write: bool, + ) -> Result<(), IMStatusCode> { + self.check_endpoint(ep) + .and_then(|endpoint| endpoint.check_attribute(accessor, cl, attr, write)) } - // Returns a slice of endpoints, with either a single endpoint or all (wildcard) - pub fn get_wildcard_endpoints( + fn check_command( &self, - endpoint: Option, - ) -> Result<(&BoxedEndpoints, usize, bool), IMStatusCode> { - if let Some(e) = endpoint { - let e = e as usize; - if self.endpoints.len() <= e || self.endpoints[e].is_none() { - Err(IMStatusCode::UnsupportedEndpoint) - } else { - Ok((&self.endpoints[e..e + 1], e, false)) - } - } else { - Ok((&self.endpoints[..], 0, true)) - } + accessor: &Accessor, + ep: EndptId, + cl: ClusterId, + cmd: CmdId, + ) -> Result<(), IMStatusCode> { + self.check_endpoint(ep) + .and_then(|endpoint| endpoint.check_command(accessor, cl, cmd)) } - pub fn get_wildcard_endpoints_mut( - &mut self, - endpoint: Option, - ) -> Result<(&mut BoxedEndpoints, usize, bool), IMStatusCode> { - if let Some(e) = endpoint { - let e = e as usize; - if self.endpoints.len() <= e || self.endpoints[e].is_none() { - Err(IMStatusCode::UnsupportedEndpoint) - } else { - Ok((&mut self.endpoints[e..e + 1], e, false)) - } - } else { - Ok((&mut self.endpoints[..], 0, true)) + fn match_endpoints(&self, ep: Option) -> impl Iterator + '_ { + self.endpoints + .iter() + .filter(move |endpoint| ep.map(|id| id == endpoint.id).unwrap_or(true)) + } + + fn check_endpoint(&self, ep: EndptId) -> Result<&Endpoint, IMStatusCode> { + self.endpoints + .iter() + .find(|endpoint| endpoint.id == ep) + .ok_or(IMStatusCode::UnsupportedEndpoint) + } +} + +impl<'a> core::fmt::Display for Node<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "node:")?; + for (index, endpoint) in self.endpoints.iter().enumerate() { + writeln!(f, "endpoint {}: {}", index, endpoint)?; } + + write!(f, "") } +} - /// Run a closure for all endpoints as specified in the path - /// - /// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour - /// of this function is to only capture the successful invocations and ignore the erroneous - /// ones. This is inline with the expected behaviour for wildcard, where it implies that - /// 'please run this operation on this wildcard path "wherever possible"' - /// - /// It is expected that if the closure that you pass here returns an error it may not reach - /// out to the caller, in case there was a wildcard path specified - pub fn for_each_endpoint(&self, path: &GenericPath, mut f: T) -> Result<(), IMStatusCode> - where - T: FnMut(&GenericPath, &Endpoint) -> Result<(), IMStatusCode>, - { - let mut current_path = *path; - let (endpoints, mut endpoint_id, wildcard) = self.get_wildcard_endpoints(path.endpoint)?; - for e in endpoints.iter() { - if let Some(e) = e { - current_path.endpoint = Some(endpoint_id as EndptId); - f(¤t_path, e.as_ref()) - .or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?; - } - endpoint_id += 1; +pub struct DynamicNode<'a, const N: usize> { + id: u16, + endpoints: heapless::Vec, N>, +} + +impl<'a, const N: usize> DynamicNode<'a, N> { + pub const fn new(id: u16) -> Self { + Self { + id, + endpoints: heapless::Vec::new(), } - Ok(()) } - /// Run a closure for all endpoints (mutable) as specified in the path - /// - /// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour - /// of this function is to only capture the successful invocations and ignore the erroneous - /// ones. This is inline with the expected behaviour for wildcard, where it implies that - /// 'please run this operation on this wildcard path "wherever possible"' - /// - /// It is expected that if the closure that you pass here returns an error it may not reach - /// out to the caller, in case there was a wildcard path specified - pub fn for_each_endpoint_mut( - &mut self, - path: &GenericPath, - mut f: T, - ) -> Result<(), IMStatusCode> - where - T: FnMut(&GenericPath, &mut Endpoint) -> Result<(), IMStatusCode>, - { - let mut current_path = *path; - let (endpoints, mut endpoint_id, wildcard) = - self.get_wildcard_endpoints_mut(path.endpoint)?; - for e in endpoints.iter_mut() { - if let Some(e) = e { - current_path.endpoint = Some(endpoint_id as EndptId); - f(¤t_path, e.as_mut()) - .or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?; - } - endpoint_id += 1; + pub fn node(&self) -> Node<'_> { + Node { + id: self.id, + endpoints: &self.endpoints, } - Ok(()) } - /// Run a closure for all clusters as specified in the path - /// - /// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour - /// of this function is to only capture the successful invocations and ignore the erroneous - /// ones. This is inline with the expected behaviour for wildcard, where it implies that - /// 'please run this operation on this wildcard path "wherever possible"' - /// - /// It is expected that if the closure that you pass here returns an error it may not reach - /// out to the caller, in case there was a wildcard path specified - pub fn for_each_cluster(&self, path: &GenericPath, mut f: T) -> Result<(), IMStatusCode> - where - T: FnMut(&GenericPath, &dyn ClusterType) -> Result<(), IMStatusCode>, - { - self.for_each_endpoint(path, |p, e| { - let mut current_path = *p; - let (clusters, wildcard) = e.get_wildcard_clusters(p.cluster)?; - for c in clusters.iter() { - current_path.cluster = Some(c.base().id); - f(¤t_path, c.as_ref()) - .or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?; - } - Ok(()) - }) + pub fn add(&mut self, endpoint: Endpoint<'a>) -> Result<(), Endpoint<'a>> { + if !self.endpoints.iter().any(|ep| ep.id == endpoint.id) { + self.endpoints.push(endpoint) + } else { + Err(endpoint) + } } - /// Run a closure for all clusters (mutable) as specified in the path - /// - /// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour - /// of this function is to only capture the successful invocations and ignore the erroneous - /// ones. This is inline with the expected behaviour for wildcard, where it implies that - /// 'please run this operation on this wildcard path "wherever possible"' - /// - /// It is expected that if the closure that you pass here returns an error it may not reach - /// out to the caller, in case there was a wildcard path specified - pub fn for_each_cluster_mut( - &mut self, - path: &GenericPath, - mut f: T, - ) -> Result<(), IMStatusCode> - where - T: FnMut(&GenericPath, &mut dyn ClusterType) -> Result<(), IMStatusCode>, - { - self.for_each_endpoint_mut(path, |p, e| { - let mut current_path = *p; - let (clusters, wildcard) = e.get_wildcard_clusters_mut(p.cluster)?; - - for c in clusters.iter_mut() { - current_path.cluster = Some(c.base().id); - f(¤t_path, c.as_mut()) - .or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?; - } - Ok(()) - }) + pub fn remove(&mut self, endpoint_id: u16) -> Option> { + let index = self + .endpoints + .iter() + .enumerate() + .find_map(|(index, ep)| (ep.id == endpoint_id).then_some(index)); + + if let Some(index) = index { + Some(self.endpoints.swap_remove(index)) + } else { + None + } } +} - /// Run a closure for all attributes as specified in the path - /// - /// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour - /// of this function is to only capture the successful invocations and ignore the erroneous - /// ones. This is inline with the expected behaviour for wildcard, where it implies that - /// 'please run this operation on this wildcard path "wherever possible"' - /// - /// It is expected that if the closure that you pass here returns an error it may not reach - /// out to the caller, in case there was a wildcard path specified - pub fn for_each_attribute(&self, path: &GenericPath, mut f: T) -> Result<(), IMStatusCode> - where - T: FnMut(&GenericPath, &dyn ClusterType) -> Result<(), IMStatusCode>, - { - self.for_each_cluster(path, |current_path, c| { - let mut current_path = *current_path; - let (attributes, wildcard) = c - .base() - .get_wildcard_attribute(path.leaf.map(|at| at as u16))?; - for a in attributes.iter() { - current_path.leaf = Some(a.id as u32); - f(¤t_path, c).or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?; - } - Ok(()) - }) +impl<'a, const N: usize> core::fmt::Display for DynamicNode<'a, N> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.node().fmt(f) } } diff --git a/matter/src/data_model/root_endpoint.rs b/matter/src/data_model/root_endpoint.rs new file mode 100644 index 00000000..44131b9f --- /dev/null +++ b/matter/src/data_model/root_endpoint.rs @@ -0,0 +1,108 @@ +use core::{borrow::Borrow, cell::RefCell}; + +use crate::{ + acl::AclMgr, + fabric::FabricMgr, + handler_chain_type, + mdns::MdnsMgr, + secure_channel::pake::PaseMgr, + utils::{epoch::Epoch, rand::Rand}, + Matter, +}; + +use super::{ + cluster_basic_information::{self, BasicInfoCluster, BasicInfoConfig}, + objects::{Cluster, EmptyHandler}, + sdm::{ + admin_commissioning::{self, AdminCommCluster}, + dev_att::DevAttDataFetcher, + failsafe::FailSafe, + general_commissioning::{self, GenCommCluster}, + noc::{self, NocCluster}, + nw_commissioning::{self, NwCommCluster}, + }, + system_model::access_control::{self, AccessControlCluster}, +}; + +pub type RootEndpointHandler<'a> = handler_chain_type!( + AccessControlCluster<'a>, + NocCluster<'a>, + AdminCommCluster<'a>, + NwCommCluster, + GenCommCluster, + BasicInfoCluster<'a> +); + +pub const CLUSTERS: [Cluster<'static>; 6] = [ + cluster_basic_information::CLUSTER, + general_commissioning::CLUSTER, + nw_commissioning::CLUSTER, + admin_commissioning::CLUSTER, + noc::CLUSTER, + access_control::CLUSTER, +]; + +pub fn handler<'a>( + endpoint_id: u16, + dev_att: &'a dyn DevAttDataFetcher, + matter: &'a Matter<'a>, +) -> RootEndpointHandler<'a> { + wrap( + endpoint_id, + matter.dev_det(), + dev_att, + matter.borrow(), + matter.borrow(), + matter.borrow(), + matter.borrow(), + matter.borrow(), + *matter.borrow(), + *matter.borrow(), + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn wrap<'a>( + endpoint_id: u16, + basic_info: &'a BasicInfoConfig<'a>, + dev_att: &'a dyn DevAttDataFetcher, + pase: &'a RefCell, + fabric: &'a RefCell, + acl: &'a RefCell, + failsafe: &'a RefCell, + mdns_mgr: &'a RefCell>, + epoch: Epoch, + rand: Rand, +) -> RootEndpointHandler<'a> { + EmptyHandler + .chain( + endpoint_id, + cluster_basic_information::CLUSTER.id, + BasicInfoCluster::new(basic_info, rand), + ) + .chain( + endpoint_id, + general_commissioning::CLUSTER.id, + GenCommCluster::new(rand), + ) + .chain( + endpoint_id, + nw_commissioning::CLUSTER.id, + NwCommCluster::new(rand), + ) + .chain( + endpoint_id, + admin_commissioning::CLUSTER.id, + AdminCommCluster::new(pase, mdns_mgr, rand), + ) + .chain( + endpoint_id, + noc::CLUSTER.id, + NocCluster::new(dev_att, fabric, acl, failsafe, mdns_mgr, epoch, rand), + ) + .chain( + endpoint_id, + access_control::CLUSTER.id, + AccessControlCluster::new(acl, rand), + ) +} diff --git a/matter/src/data_model/sdm/admin_commissioning.rs b/matter/src/data_model/sdm/admin_commissioning.rs index fb317229..5497426b 100644 --- a/matter/src/data_model/sdm/admin_commissioning.rs +++ b/matter/src/data_model/sdm/admin_commissioning.rs @@ -15,15 +15,21 @@ * limitations under the License. */ -use crate::cmd_enter; +use core::cell::RefCell; +use core::convert::TryInto; + use crate::data_model::objects::*; -use crate::interaction_model::core::IMStatusCode; +use crate::interaction_model::core::Transaction; +use crate::mdns::MdnsMgr; use crate::secure_channel::pake::PaseMgr; use crate::secure_channel::spake2p::VerifierData; use crate::tlv::{FromTLV, Nullable, OctetStr, TLVElement}; -use crate::{error::*, interaction_model::command::CommandReq}; -use log::{error, info}; +use crate::utils::rand::Rand; +use crate::{attribute_enum, cmd_enter}; +use crate::{command_enum, error::*}; +use log::info; use num_derive::FromPrimitive; +use strum::{EnumDiscriminants, FromRepr}; pub const ID: u32 = 0x003C; @@ -34,127 +40,152 @@ pub enum WindowStatus { BasicWindowOpen = 2, } -#[derive(FromPrimitive)] +#[derive(Copy, Clone, Debug, FromRepr, EnumDiscriminants)] +#[repr(u16)] pub enum Attributes { - WindowStatus = 0, - AdminFabricIndex = 1, - AdminVendorId = 2, + WindowStatus(AttrType) = 0, + AdminFabricIndex(AttrType>) = 1, + AdminVendorId(AttrType>) = 2, } -#[derive(FromPrimitive)] +attribute_enum!(Attributes); + +#[derive(FromRepr)] +#[repr(u32)] pub enum Commands { OpenCommWindow = 0x00, OpenBasicCommWindow = 0x01, RevokeComm = 0x02, } -fn attr_window_status_new() -> Attribute { - Attribute::new( - Attributes::WindowStatus as u16, - AttrValue::Custom, - Access::RV, - Quality::NONE, - ) -} +command_enum!(Commands); -fn attr_admin_fabid_new() -> Attribute { - Attribute::new( - Attributes::AdminFabricIndex as u16, - AttrValue::Custom, - Access::RV, - Quality::NULLABLE, - ) -} +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID as _, + feature_map: 0, + attributes: &[ + FEATURE_MAP, + ATTRIBUTE_LIST, + Attribute::new( + AttributesDiscriminants::WindowStatus as u16, + Access::RV, + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::AdminFabricIndex as u16, + Access::RV, + Quality::NULLABLE, + ), + Attribute::new( + AttributesDiscriminants::AdminVendorId as u16, + Access::RV, + Quality::NULLABLE, + ), + ], + commands: &[ + Commands::OpenCommWindow as _, + Commands::OpenBasicCommWindow as _, + Commands::RevokeComm as _, + ], +}; -fn attr_admin_vid_new() -> Attribute { - Attribute::new( - Attributes::AdminVendorId as u16, - AttrValue::Custom, - Access::RV, - Quality::NULLABLE, - ) +#[derive(FromTLV)] +#[tlvargs(lifetime = "'a")] +pub struct OpenCommWindowReq<'a> { + _timeout: u16, + verifier: OctetStr<'a>, + discriminator: u16, + iterations: u32, + salt: OctetStr<'a>, } -pub struct AdminCommCluster { - pase_mgr: PaseMgr, - base: Cluster, +pub struct AdminCommCluster<'a> { + data_ver: Dataver, + pase_mgr: &'a RefCell, + mdns_mgr: &'a RefCell>, } -impl ClusterType for AdminCommCluster { - fn base(&self) -> &Cluster { - &self.base - } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base +impl<'a> AdminCommCluster<'a> { + pub fn new( + pase_mgr: &'a RefCell, + mdns_mgr: &'a RefCell>, + rand: Rand, + ) -> Self { + Self { + data_ver: Dataver::new(rand), + pase_mgr, + mdns_mgr, + } } - fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { - match num::FromPrimitive::from_u16(attr.attr_id) { - Some(Attributes::WindowStatus) => { - let status = 1_u8; - encoder.encode(EncodeValue::Value(&status)) - } - Some(Attributes::AdminVendorId) => { - let vid = Nullable::NotNull(1_u8); - - encoder.encode(EncodeValue::Value(&vid)) - } - Some(Attributes::AdminFabricIndex) => { - let vid = Nullable::NotNull(1_u8); - encoder.encode(EncodeValue::Value(&vid)) - } - _ => { - error!("Unsupported Attribute: this shouldn't happen"); + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + match attr.attr_id.try_into()? { + Attributes::WindowStatus(codec) => codec.encode(writer, 1), + Attributes::AdminVendorId(codec) => codec.encode(writer, Nullable::NotNull(1)), + Attributes::AdminFabricIndex(codec) => { + codec.encode(writer, Nullable::NotNull(1)) + } + } } + } else { + Ok(()) } } - fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { - let cmd = cmd_req - .cmd - .path - .leaf - .map(num::FromPrimitive::from_u32) - .ok_or(IMStatusCode::UnsupportedCommand)? - .ok_or(IMStatusCode::UnsupportedCommand)?; - match cmd { - Commands::OpenCommWindow => self.handle_command_opencomm_win(cmd_req), - _ => Err(IMStatusCode::UnsupportedCommand), + + pub fn invoke( + &mut self, + cmd: &CmdDetails, + data: &TLVElement, + _encoder: CmdDataEncoder, + ) -> Result<(), Error> { + match cmd.cmd_id.try_into()? { + Commands::OpenCommWindow => self.handle_command_opencomm_win(data)?, + _ => Err(Error::CommandNotFound)?, } + + self.data_ver.changed(); + + Ok(()) + } + + fn handle_command_opencomm_win(&mut self, data: &TLVElement) -> Result<(), Error> { + cmd_enter!("Open Commissioning Window"); + let req = OpenCommWindowReq::from_tlv(data)?; + let verifier = VerifierData::new(req.verifier.0, req.iterations, req.salt.0); + self.pase_mgr.borrow_mut().enable_pase_session( + verifier, + req.discriminator, + &mut self.mdns_mgr.borrow_mut(), + )?; + + Ok(()) } } -impl AdminCommCluster { - pub fn new(pase_mgr: PaseMgr) -> Result, Error> { - let mut c = Box::new(AdminCommCluster { - pase_mgr, - base: Cluster::new(ID)?, - }); - c.base.add_attribute(attr_window_status_new())?; - c.base.add_attribute(attr_admin_fabid_new())?; - c.base.add_attribute(attr_admin_vid_new())?; - Ok(c) +impl<'a> Handler for AdminCommCluster<'a> { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + AdminCommCluster::read(self, attr, encoder) } - fn handle_command_opencomm_win( + fn invoke( &mut self, - cmd_req: &mut CommandReq, - ) -> Result<(), IMStatusCode> { - cmd_enter!("Open Commissioning Window"); - let req = - OpenCommWindowReq::from_tlv(&cmd_req.data).map_err(|_| IMStatusCode::InvalidCommand)?; - let verifier = VerifierData::new(req.verifier.0, req.iterations, req.salt.0); - self.pase_mgr - .enable_pase_session(verifier, req.discriminator)?; - Err(IMStatusCode::Success) + _transaction: &mut Transaction, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + AdminCommCluster::invoke(self, cmd, data, encoder) } } -#[derive(FromTLV)] -#[tlvargs(lifetime = "'a")] -pub struct OpenCommWindowReq<'a> { - _timeout: u16, - verifier: OctetStr<'a>, - discriminator: u16, - iterations: u32, - salt: OctetStr<'a>, +impl<'a> NonBlockingHandler for AdminCommCluster<'a> {} + +impl<'a> ChangeNotifier<()> for AdminCommCluster<'a> { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) + } } diff --git a/matter/src/data_model/sdm/failsafe.rs b/matter/src/data_model/sdm/failsafe.rs index cd3c2be0..54b22e6a 100644 --- a/matter/src/data_model/sdm/failsafe.rs +++ b/matter/src/data_model/sdm/failsafe.rs @@ -17,7 +17,6 @@ use crate::{error::Error, transport::session::SessionMode}; use log::error; -use std::sync::RwLock; #[derive(PartialEq)] #[allow(dead_code)] @@ -42,26 +41,19 @@ pub enum State { Armed(ArmedCtx), } -pub struct FailSafeInner { - state: State, -} - pub struct FailSafe { - state: RwLock, + state: State, } impl FailSafe { - pub fn new() -> Self { - Self { - state: RwLock::new(FailSafeInner { state: State::Idle }), - } + pub const fn new() -> Self { + Self { state: State::Idle } } - pub fn arm(&self, timeout: u8, session_mode: SessionMode) -> Result<(), Error> { - let mut inner = self.state.write()?; - match &mut inner.state { + pub fn arm(&mut self, timeout: u8, session_mode: SessionMode) -> Result<(), Error> { + match &mut self.state { State::Idle => { - inner.state = State::Armed(ArmedCtx { + self.state = State::Armed(ArmedCtx { session_mode, timeout, noc_state: NocState::NocNotRecvd, @@ -78,9 +70,8 @@ impl FailSafe { Ok(()) } - pub fn disarm(&self, session_mode: SessionMode) -> Result<(), Error> { - let mut inner = self.state.write()?; - match &mut inner.state { + pub fn disarm(&mut self, session_mode: SessionMode) -> Result<(), Error> { + match &mut self.state { State::Idle => { error!("Received Fail-Safe Disarm without it being armed"); return Err(Error::Invalid); @@ -102,19 +93,18 @@ impl FailSafe { } } } - inner.state = State::Idle; + self.state = State::Idle; } } Ok(()) } pub fn is_armed(&self) -> bool { - self.state.read().unwrap().state != State::Idle + self.state != State::Idle } - pub fn record_add_noc(&self, fabric_index: u8) -> Result<(), Error> { - let mut inner = self.state.write()?; - match &mut inner.state { + pub fn record_add_noc(&mut self, fabric_index: u8) -> Result<(), Error> { + match &mut self.state { State::Idle => Err(Error::Invalid), State::Armed(c) => { if c.noc_state == NocState::NocNotRecvd { @@ -128,8 +118,7 @@ impl FailSafe { } pub fn allow_noc_change(&self) -> Result { - let mut inner = self.state.write()?; - let allow = match &mut inner.state { + let allow = match &self.state { State::Idle => false, State::Armed(c) => c.noc_state == NocState::NocNotRecvd, }; diff --git a/matter/src/data_model/sdm/general_commissioning.rs b/matter/src/data_model/sdm/general_commissioning.rs index 0328b213..aea37c7a 100644 --- a/matter/src/data_model/sdm/general_commissioning.rs +++ b/matter/src/data_model/sdm/general_commissioning.rs @@ -15,16 +15,19 @@ * limitations under the License. */ -use crate::cmd_enter; +use core::cell::RefCell; +use core::convert::TryInto; + use crate::data_model::objects::*; use crate::data_model::sdm::failsafe::FailSafe; -use crate::interaction_model::core::IMStatusCode; -use crate::interaction_model::messages::ib; -use crate::tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}; -use crate::{error::*, interaction_model::command::CommandReq}; -use log::{error, info}; -use num_derive::FromPrimitive; -use std::sync::Arc; +use crate::interaction_model::core::Transaction; +use crate::tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; +use crate::transport::session::Session; +use crate::utils::rand::Rand; +use crate::{attribute_enum, cmd_enter}; +use crate::{command_enum, error::*}; +use log::info; +use strum::{EnumDiscriminants, FromRepr}; #[derive(Clone, Copy)] #[allow(dead_code)] @@ -38,65 +41,80 @@ enum CommissioningError { pub const ID: u32 = 0x0030; -#[derive(FromPrimitive)] +#[derive(FromRepr, EnumDiscriminants)] +#[repr(u16)] pub enum Attributes { - BreadCrumb = 0, - BasicCommissioningInfo = 1, - RegConfig = 2, - LocationCapability = 3, + BreadCrumb(AttrType) = 0, + BasicCommissioningInfo(()) = 1, + RegConfig(AttrType) = 2, + LocationCapability(AttrType) = 3, } -#[derive(FromPrimitive)] +attribute_enum!(Attributes); + +#[derive(FromRepr)] +#[repr(u32)] pub enum Commands { ArmFailsafe = 0x00, - ArmFailsafeResp = 0x01, SetRegulatoryConfig = 0x02, - SetRegulatoryConfigResp = 0x03, CommissioningComplete = 0x04, - CommissioningCompleteResp = 0x05, } -pub enum RegLocationType { - Indoor = 0, - Outdoor = 1, - IndoorOutdoor = 2, -} +command_enum!(Commands); -fn attr_bread_crumb_new(bread_crumb: u64) -> Attribute { - Attribute::new( - Attributes::BreadCrumb as u16, - AttrValue::Uint64(bread_crumb), - Access::READ | Access::WRITE | Access::NEED_ADMIN, - Quality::NONE, - ) +#[repr(u16)] +pub enum RespCommands { + ArmFailsafeResp = 0x01, + SetRegulatoryConfigResp = 0x03, + CommissioningCompleteResp = 0x05, } -fn attr_reg_config_new(reg_config: RegLocationType) -> Attribute { - Attribute::new( - Attributes::RegConfig as u16, - AttrValue::Uint8(reg_config as u8), - Access::RV, - Quality::NONE, - ) +#[derive(FromTLV, ToTLV)] +#[tlvargs(lifetime = "'a")] +struct CommonResponse<'a> { + error_code: u8, + debug_txt: UtfStr<'a>, } -fn attr_location_capability_new(reg_config: RegLocationType) -> Attribute { - Attribute::new( - Attributes::LocationCapability as u16, - AttrValue::Uint8(reg_config as u8), - Access::RV, - Quality::FIXED, - ) +pub enum RegLocationType { + Indoor = 0, + Outdoor = 1, + IndoorOutdoor = 2, } -fn attr_comm_info_new() -> Attribute { - Attribute::new( - Attributes::BasicCommissioningInfo as u16, - AttrValue::Custom, - Access::RV, - Quality::FIXED, - ) -} +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID as _, + feature_map: 0, + attributes: &[ + FEATURE_MAP, + ATTRIBUTE_LIST, + Attribute::new( + AttributesDiscriminants::BreadCrumb as u16, + Access::READ.union(Access::WRITE).union(Access::NEED_ADMIN), + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::RegConfig as u16, + Access::RV, + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::LocationCapability as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::BasicCommissioningInfo as u16, + Access::RV, + Quality::FIXED, + ), + ], + commands: &[ + Commands::ArmFailsafe as _, + Commands::SetRegulatoryConfig as _, + Commands::CommissioningComplete as _, + ], +}; #[derive(FromTLV, ToTLV)] struct FailSafeParams { @@ -105,143 +123,134 @@ struct FailSafeParams { } pub struct GenCommCluster { + data_ver: Dataver, expiry_len: u16, - failsafe: Arc, - base: Cluster, + failsafe: RefCell, } -impl ClusterType for GenCommCluster { - fn base(&self) -> &Cluster { - &self.base +impl GenCommCluster { + pub fn new(rand: Rand) -> Self { + Self { + data_ver: Dataver::new(rand), + failsafe: RefCell::new(FailSafe::new()), + // TODO: Arch-Specific + expiry_len: 120, + } } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base + + pub fn failsafe(&self) -> &RefCell { + &self.failsafe } - fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { - match num::FromPrimitive::from_u16(attr.attr_id) { - Some(Attributes::BasicCommissioningInfo) => { - encoder.encode(EncodeValue::Closure(&|tag, tw| { - let _ = tw.start_struct(tag); - let _ = tw.u16(TagType::Context(0), self.expiry_len); - let _ = tw.end_container(); - })) - } - _ => { - error!("Unsupported Attribute: this shouldn't happen"); + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(mut writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + match attr.attr_id.try_into()? { + Attributes::BreadCrumb(codec) => codec.encode(writer, 0), + // TODO: Arch-Specific + Attributes::RegConfig(codec) => { + codec.encode(writer, RegLocationType::IndoorOutdoor as _) + } + // TODO: Arch-Specific + Attributes::LocationCapability(codec) => { + codec.encode(writer, RegLocationType::IndoorOutdoor as _) + } + Attributes::BasicCommissioningInfo(_) => { + writer.start_struct(AttrDataWriter::TAG)?; + writer.u16(TagType::Context(0), self.expiry_len)?; + writer.end_container()?; + + writer.complete() + } + } } + } else { + Ok(()) } } - fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { - let cmd = cmd_req - .cmd - .path - .leaf - .map(num::FromPrimitive::from_u32) - .ok_or(IMStatusCode::UnsupportedCommand)? - .ok_or(IMStatusCode::UnsupportedCommand)?; - match cmd { - Commands::ArmFailsafe => self.handle_command_armfailsafe(cmd_req), - Commands::SetRegulatoryConfig => self.handle_command_setregulatoryconfig(cmd_req), - Commands::CommissioningComplete => self.handle_command_commissioningcomplete(cmd_req), - _ => Err(IMStatusCode::UnsupportedCommand), + pub fn invoke( + &mut self, + session: &mut Session, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + match cmd.cmd_id.try_into()? { + Commands::ArmFailsafe => self.handle_command_armfailsafe(session, data, encoder)?, + Commands::SetRegulatoryConfig => { + self.handle_command_setregulatoryconfig(data, encoder)? + } + Commands::CommissioningComplete => { + self.handle_command_commissioningcomplete(session, encoder)?; + } } - } -} - -impl GenCommCluster { - pub fn new() -> Result, Error> { - let failsafe = Arc::new(FailSafe::new()); - let mut c = Box::new(GenCommCluster { - // TODO: Arch-Specific - expiry_len: 120, - failsafe, - base: Cluster::new(ID)?, - }); - c.base.add_attribute(attr_bread_crumb_new(0))?; - // TODO: Arch-Specific - c.base - .add_attribute(attr_reg_config_new(RegLocationType::IndoorOutdoor))?; - // TODO: Arch-Specific - c.base - .add_attribute(attr_location_capability_new(RegLocationType::IndoorOutdoor))?; - c.base.add_attribute(attr_comm_info_new())?; - - Ok(c) - } + self.data_ver.changed(); - pub fn failsafe(&self) -> Arc { - self.failsafe.clone() + Ok(()) } - fn handle_command_armfailsafe(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { + fn handle_command_armfailsafe( + &mut self, + session: &Session, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { cmd_enter!("ARM Fail Safe"); - let p = FailSafeParams::from_tlv(&cmd_req.data)?; - let mut status = CommissioningError::Ok as u8; + let p = FailSafeParams::from_tlv(data)?; - if self - .failsafe - .arm(p.expiry_len, cmd_req.trans.session.get_session_mode()) - .is_err() - { - status = CommissioningError::ErrBusyWithOtherAdmin as u8; - } + self.failsafe + .borrow_mut() + .arm(p.expiry_len, session.get_session_mode()) + .map_err(|e| e.remap(|_| true, Error::Busy))?; let cmd_data = CommonResponse { - error_code: status, - debug_txt: "".to_owned(), + error_code: CommissioningError::ErrBusyWithOtherAdmin as u8, + debug_txt: UtfStr::new(b""), }; - let resp = ib::InvResp::cmd_new( - 0, - ID, - Commands::ArmFailsafeResp as u16, - EncodeValue::Value(&cmd_data), - ); - let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous); - cmd_req.trans.complete(); - Ok(()) + + encoder + .with_command(RespCommands::ArmFailsafeResp as _)? + .set(&cmd_data) } fn handle_command_setregulatoryconfig( &mut self, - cmd_req: &mut CommandReq, - ) -> Result<(), IMStatusCode> { + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { cmd_enter!("Set Regulatory Config"); - let country_code = cmd_req - .data + let country_code = data .find_tag(1) - .map_err(|_| IMStatusCode::InvalidCommand)? + .map_err(|_| Error::InvalidCommand)? .slice() - .map_err(|_| IMStatusCode::InvalidCommand)?; + .map_err(|_| Error::InvalidCommand)?; info!("Received country code: {:?}", country_code); let cmd_data = CommonResponse { error_code: 0, - debug_txt: "".to_owned(), + debug_txt: UtfStr::new(b""), }; - let resp = ib::InvResp::cmd_new( - 0, - ID, - Commands::SetRegulatoryConfigResp as u16, - EncodeValue::Value(&cmd_data), - ); - let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous); - cmd_req.trans.complete(); - Ok(()) + + encoder + .with_command(RespCommands::SetRegulatoryConfigResp as _)? + .set(&cmd_data) } fn handle_command_commissioningcomplete( &mut self, - cmd_req: &mut CommandReq, - ) -> Result<(), IMStatusCode> { + session: &Session, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { cmd_enter!("Commissioning Complete"); let mut status: u8 = CommissioningError::Ok as u8; // Has to be a Case Session - if cmd_req.trans.session.get_local_fabric_idx().is_none() { + if session.get_local_fabric_idx().is_none() { status = CommissioningError::ErrInvalidAuth as u8; } @@ -249,7 +258,8 @@ impl GenCommCluster { // scope that is for this session if self .failsafe - .disarm(cmd_req.trans.session.get_session_mode()) + .borrow_mut() + .disarm(session.get_session_mode()) .is_err() { status = CommissioningError::ErrInvalidAuth as u8; @@ -257,22 +267,35 @@ impl GenCommCluster { let cmd_data = CommonResponse { error_code: status, - debug_txt: "".to_owned(), + debug_txt: UtfStr::new(b""), }; - let resp = ib::InvResp::cmd_new( - 0, - ID, - Commands::CommissioningCompleteResp as u16, - EncodeValue::Value(&cmd_data), - ); - let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous); - cmd_req.trans.complete(); - Ok(()) + + encoder + .with_command(RespCommands::CommissioningCompleteResp as _)? + .set(&cmd_data) } } -#[derive(FromTLV, ToTLV)] -struct CommonResponse { - error_code: u8, - debug_txt: String, +impl Handler for GenCommCluster { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + GenCommCluster::read(self, attr, encoder) + } + + fn invoke( + &mut self, + transaction: &mut Transaction, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + GenCommCluster::invoke(self, transaction.session_mut(), cmd, data, encoder) + } +} + +impl NonBlockingHandler for GenCommCluster {} + +impl ChangeNotifier<()> for GenCommCluster { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) + } } diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index 7c85f59b..0258f3a5 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -15,24 +15,25 @@ * limitations under the License. */ -use std::sync::Arc; -use std::time::{SystemTime, UNIX_EPOCH}; +use core::cell::RefCell; +use core::convert::TryInto; use crate::acl::{AclEntry, AclMgr, AuthMode}; use crate::cert::Cert; -use crate::crypto::{self, CryptoKeyPair, KeyPair}; +use crate::crypto::{self, KeyPair}; use crate::data_model::objects::*; use crate::data_model::sdm::dev_att; use crate::fabric::{Fabric, FabricMgr, MAX_SUPPORTED_FABRICS}; -use crate::interaction_model::command::CommandReq; -use crate::interaction_model::core::IMStatusCode; -use crate::interaction_model::messages::ib; +use crate::interaction_model::core::Transaction; +use crate::mdns::MdnsMgr; use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; -use crate::transport::session::SessionMode; +use crate::transport::session::{Session, SessionMode}; +use crate::utils::epoch::Epoch; +use crate::utils::rand::Rand; use crate::utils::writebuf::WriteBuf; -use crate::{cmd_enter, error::*, secure_channel}; +use crate::{attribute_enum, cmd_enter, command_enum, error::*}; use log::{error, info}; -use num_derive::FromPrimitive; +use strum::{EnumDiscriminants, FromRepr}; use super::dev_att::{DataType, DevAttDataFetcher}; use super::failsafe::FailSafe; @@ -56,6 +57,23 @@ enum NocStatus { InvalidFabricIndex = 11, } +enum NocError { + Status(NocStatus), + Error(Error), +} + +impl From for NocError { + fn from(value: NocStatus) -> Self { + Self::Status(value) + } +} + +impl From for NocError { + fn from(value: Error) -> Self { + Self::Error(value) + } +} + // Some placeholder value for now const MAX_CERT_DECLARATION_LEN: usize = 600; // Some placeholder value for now @@ -65,39 +83,80 @@ const RESP_MAX: usize = 900; pub const ID: u32 = 0x003E; -#[derive(FromPrimitive)] +#[derive(FromRepr)] +#[repr(u32)] pub enum Commands { AttReq = 0x00, - AttReqResp = 0x01, CertChainReq = 0x02, - CertChainResp = 0x03, CSRReq = 0x04, - CSRResp = 0x05, AddNOC = 0x06, - NOCResp = 0x08, UpdateFabricLabel = 0x09, RemoveFabric = 0x0a, AddTrustedRootCert = 0x0b, } -#[derive(FromPrimitive)] +command_enum!(Commands); + +#[repr(u16)] +pub enum RespCommands { + AttReqResp = 0x01, + CertChainResp = 0x03, + CSRResp = 0x05, + NOCResp = 0x08, +} + +#[derive(FromRepr, EnumDiscriminants)] +#[repr(u16)] pub enum Attributes { NOCs = 0, - Fabrics = 1, - SupportedFabrics = 2, - CommissionedFabrics = 3, + Fabrics(()) = 1, + SupportedFabrics(AttrType) = 2, + CommissionedFabrics(AttrType) = 3, TrustedRootCerts = 4, - CurrentFabricIndex = 5, + CurrentFabricIndex(AttrType) = 5, } -pub struct NocCluster { - base: Cluster, - dev_att: Box, - fabric_mgr: Arc, - acl_mgr: Arc, - failsafe: Arc, -} -struct NocData { +attribute_enum!(Attributes); + +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID as _, + feature_map: 0, + attributes: &[ + FEATURE_MAP, + ATTRIBUTE_LIST, + Attribute::new( + AttributesDiscriminants::CurrentFabricIndex as u16, + Access::RV, + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::Fabrics as u16, + Access::RV.union(Access::FAB_SCOPED), + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::SupportedFabrics as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::CommissionedFabrics as u16, + Access::RV, + Quality::NONE, + ), + ], + commands: &[ + Commands::AttReq as _, + Commands::CertChainReq as _, + Commands::CSRReq as _, + Commands::AddNOC as _, + Commands::UpdateFabricLabel as _, + Commands::RemoveFabric as _, + Commands::AddTrustedRootCert as _, + ], +}; + +pub struct NocData { pub key_pair: KeyPair, pub root_ca: Cert, } @@ -111,82 +170,186 @@ impl NocData { } } -impl NocCluster { +#[derive(ToTLV)] +struct CertChainResp<'a> { + cert: OctetStr<'a>, +} + +#[derive(ToTLV)] +struct NocResp<'a> { + status_code: u8, + fab_idx: u8, + debug_txt: UtfStr<'a>, +} + +#[derive(FromTLV)] +#[tlvargs(lifetime = "'a")] +struct AddNocReq<'a> { + noc_value: OctetStr<'a>, + icac_value: OctetStr<'a>, + ipk_value: OctetStr<'a>, + case_admin_subject: u64, + vendor_id: u16, +} + +#[derive(FromTLV)] +#[tlvargs(lifetime = "'a")] +struct CommonReq<'a> { + str: OctetStr<'a>, +} + +#[derive(FromTLV)] +#[tlvargs(lifetime = "'a")] +struct UpdateFabricLabelReq<'a> { + label: UtfStr<'a>, +} + +#[derive(FromTLV)] +struct CertChainReq { + cert_type: u8, +} + +#[derive(FromTLV)] +struct RemoveFabricReq { + fab_idx: u8, +} + +pub struct NocCluster<'a> { + data_ver: Dataver, + epoch: Epoch, + dev_att: &'a dyn DevAttDataFetcher, + fabric_mgr: &'a RefCell, + acl_mgr: &'a RefCell, + failsafe: &'a RefCell, + mdns_mgr: &'a RefCell>, +} + +impl<'a> NocCluster<'a> { pub fn new( - dev_att: Box, - fabric_mgr: Arc, - acl_mgr: Arc, - failsafe: Arc, - ) -> Result, Error> { - let mut c = Box::new(Self { + dev_att: &'a dyn DevAttDataFetcher, + fabric_mgr: &'a RefCell, + acl_mgr: &'a RefCell, + failsafe: &'a RefCell, + mdns_mgr: &'a RefCell>, + epoch: Epoch, + rand: Rand, + ) -> Self { + Self { + data_ver: Dataver::new(rand), + epoch, dev_att, fabric_mgr, acl_mgr, failsafe, - base: Cluster::new(ID)?, - }); - let attrs = [ - Attribute::new( - Attributes::CurrentFabricIndex as u16, - AttrValue::Custom, - Access::RV, - Quality::NONE, - ), - Attribute::new( - Attributes::Fabrics as u16, - AttrValue::Custom, - Access::RV | Access::FAB_SCOPED, - Quality::NONE, - ), - Attribute::new( - Attributes::SupportedFabrics as u16, - AttrValue::Uint8(MAX_SUPPORTED_FABRICS as u8), - Access::RV, - Quality::FIXED, - ), - Attribute::new( - Attributes::CommissionedFabrics as u16, - AttrValue::Custom, - Access::RV, - Quality::NONE, - ), - ]; - c.base.add_attributes(&attrs[..])?; - Ok(c) + mdns_mgr, + } + } + + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(mut writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + match attr.attr_id.try_into()? { + Attributes::SupportedFabrics(codec) => { + codec.encode(writer, MAX_SUPPORTED_FABRICS as _) + } + Attributes::CurrentFabricIndex(codec) => codec.encode(writer, attr.fab_idx), + Attributes::Fabrics(_) => { + writer.start_array(AttrDataWriter::TAG)?; + self.fabric_mgr.borrow().for_each(|entry, fab_idx| { + if !attr.fab_filter || attr.fab_idx == fab_idx { + entry + .get_fabric_desc(fab_idx) + .to_tlv(&mut writer, TagType::Anonymous)?; + } + + Ok(()) + })?; + writer.end_container()?; + + writer.complete() + } + Attributes::CommissionedFabrics(codec) => { + codec.encode(writer, self.fabric_mgr.borrow().used_count() as _) + } + _ => { + error!("Attribute not supported: this shouldn't happen"); + Err(Error::AttributeNotFound) + } + } + } + } else { + Ok(()) + } + } + + pub fn invoke( + &mut self, + transaction: &mut Transaction, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + match cmd.cmd_id.try_into()? { + Commands::AddNOC => { + self.handle_command_addnoc(transaction.session_mut(), data, encoder)? + } + Commands::CSRReq => { + self.handle_command_csrrequest(transaction.session_mut(), data, encoder)? + } + Commands::AddTrustedRootCert => { + self.handle_command_addtrustedrootcert(transaction.session_mut(), data)? + } + Commands::AttReq => { + self.handle_command_attrequest(transaction.session(), data, encoder)? + } + Commands::CertChainReq => self.handle_command_certchainrequest(data, encoder)?, + Commands::UpdateFabricLabel => { + self.handle_command_updatefablabel(transaction.session(), data, encoder)?; + } + Commands::RemoveFabric => self.handle_command_rmfabric(transaction, data, encoder)?, + } + + self.data_ver.changed(); + + Ok(()) } fn add_acl(&self, fab_idx: u8, admin_subject: u64) -> Result<(), Error> { let mut acl = AclEntry::new(fab_idx, Privilege::ADMIN, AuthMode::Case); acl.add_subject(admin_subject)?; - self.acl_mgr.add(acl) + self.acl_mgr.borrow_mut().add(acl) } - fn _handle_command_addnoc(&mut self, cmd_req: &mut CommandReq) -> Result<(), NocStatus> { - let noc_data = cmd_req - .trans - .session - .take_data::() - .ok_or(NocStatus::MissingCsr)?; + fn _handle_command_addnoc( + &mut self, + session: &mut Session, + data: &TLVElement, + ) -> Result { + let noc_data = session.take_noc_data().ok_or(NocStatus::MissingCsr)?; if !self .failsafe + .borrow_mut() .allow_noc_change() .map_err(|_| NocStatus::InsufficientPrivlege)? { error!("AddNOC not allowed by Fail Safe"); - return Err(NocStatus::InsufficientPrivlege); + Err(NocStatus::InsufficientPrivlege)?; } - // This command's processing may take longer, send a stand alone ACK to the peer to avoid any retranmissions - let ack_send = secure_channel::common::send_mrp_standalone_ack( - cmd_req.trans.exch, - cmd_req.trans.session, - ); - if ack_send.is_err() { - error!("Error sending Standalone ACK, falling back to piggybacked ACK"); - } + // TODO + // // This command's processing may take longer, send a stand alone ACK to the peer to avoid any retranmissions + // let ack_send = secure_channel::common::send_mrp_standalone_ack( + // trans.exch, + // trans.session, + // ); + // if ack_send.is_err() { + // error!("Error sending Standalone ACK, falling back to piggybacked ACK"); + // } - let r = AddNocReq::from_tlv(&cmd_req.data).map_err(|_| NocStatus::InvalidNOC)?; + let r = AddNocReq::from_tlv(data).map_err(|_| NocStatus::InvalidNOC)?; let noc_value = Cert::new(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?; info!("Received NOC as: {}", noc_value); @@ -197,6 +360,7 @@ impl NocCluster { } else { None }; + let fabric = Fabric::new( noc_data.key_pair, noc_data.root_ca, @@ -204,298 +368,265 @@ impl NocCluster { noc_value, r.ipk_value.0, r.vendor_id, + "", ) .map_err(|_| NocStatus::TableFull)?; let fab_idx = self .fabric_mgr - .add(fabric) + .borrow_mut() + .add(fabric, &mut self.mdns_mgr.borrow_mut()) .map_err(|_| NocStatus::TableFull)?; - if self.add_acl(fab_idx, r.case_admin_subject).is_err() { - error!("Failed to add ACL, what to do?"); - } + self.add_acl(fab_idx, r.case_admin_subject)?; - if self.failsafe.record_add_noc(fab_idx).is_err() { - error!("Failed to record NoC in the FailSafe, what to do?"); - } - NocCluster::create_nocresponse(cmd_req.resp, NocStatus::Ok, fab_idx, "".to_owned()); - cmd_req.trans.complete(); - Ok(()) + self.failsafe.borrow_mut().record_add_noc(fab_idx)?; + + Ok(fab_idx) } fn create_nocresponse( - tw: &mut TLVWriter, + encoder: CmdDataEncoder, status_code: NocStatus, fab_idx: u8, - debug_txt: String, - ) { + debug_txt: &str, + ) -> Result<(), Error> { let cmd_data = NocResp { status_code: status_code as u8, fab_idx, - debug_txt, + debug_txt: UtfStr::new(debug_txt.as_bytes()), }; - let invoke_resp = ib::InvResp::cmd_new( - 0, - ID, - Commands::NOCResp as u16, - EncodeValue::Value(&cmd_data), - ); - let _ = invoke_resp.to_tlv(tw, TagType::Anonymous); + + encoder + .with_command(RespCommands::NOCResp as _)? + .set(&cmd_data) } fn handle_command_updatefablabel( &mut self, - cmd_req: &mut CommandReq, - ) -> Result<(), IMStatusCode> { + session: &Session, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { cmd_enter!("Update Fabric Label"); - let req = UpdateFabricLabelReq::from_tlv(&cmd_req.data) - .map_err(|_| IMStatusCode::InvalidDataType)?; - let label = req - .label - .to_string() - .map_err(|_| IMStatusCode::InvalidDataType)?; - - let (result, fab_idx) = - if let SessionMode::Case(c) = cmd_req.trans.session.get_session_mode() { - if self.fabric_mgr.set_label(c.fab_idx, label).is_err() { - (NocStatus::LabelConflict, c.fab_idx) - } else { - (NocStatus::Ok, c.fab_idx) - } + let req = UpdateFabricLabelReq::from_tlv(data).map_err(Error::map_invalid_data_type)?; + let (result, fab_idx) = if let SessionMode::Case(c) = session.get_session_mode() { + if self + .fabric_mgr + .borrow_mut() + .set_label( + c.fab_idx, + req.label.as_str().map_err(Error::map_invalid_data_type)?, + ) + .is_err() + { + (NocStatus::LabelConflict, c.fab_idx) } else { - // Update Fabric Label not allowed - (NocStatus::InvalidFabricIndex, 0) - }; - NocCluster::create_nocresponse(cmd_req.resp, result, fab_idx, "".to_string()); - cmd_req.trans.complete(); - Ok(()) + (NocStatus::Ok, c.fab_idx) + } + } else { + // Update Fabric Label not allowed + (NocStatus::InvalidFabricIndex, 0) + }; + + Self::create_nocresponse(encoder, result, fab_idx, "") } - fn handle_command_rmfabric(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { + fn handle_command_rmfabric( + &mut self, + transaction: &mut Transaction, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { cmd_enter!("Remove Fabric"); - let req = - RemoveFabricReq::from_tlv(&cmd_req.data).map_err(|_| IMStatusCode::InvalidCommand)?; - if self.fabric_mgr.remove(req.fab_idx).is_ok() { - let _ = self.acl_mgr.delete_for_fabric(req.fab_idx); - cmd_req.trans.terminate(); + let req = RemoveFabricReq::from_tlv(data).map_err(Error::map_invalid_data_type)?; + if self + .fabric_mgr + .borrow_mut() + .remove(req.fab_idx, &mut self.mdns_mgr.borrow_mut()) + .is_ok() + { + let _ = self.acl_mgr.borrow_mut().delete_for_fabric(req.fab_idx); + transaction.terminate(); + Ok(()) } else { - NocCluster::create_nocresponse( - cmd_req.resp, - NocStatus::InvalidFabricIndex, - req.fab_idx, - "".to_string(), - ); + Self::create_nocresponse(encoder, NocStatus::InvalidFabricIndex, req.fab_idx, "") } - Ok(()) } - fn handle_command_addnoc(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { + fn handle_command_addnoc( + &mut self, + session: &mut Session, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { cmd_enter!("AddNOC"); - if let Err(e) = self._handle_command_addnoc(cmd_req) { - //TODO: Fab-idx 0? - NocCluster::create_nocresponse(cmd_req.resp, e, 0, "".to_owned()); - cmd_req.trans.complete(); - } - Ok(()) + + let (status, fab_idx) = match self._handle_command_addnoc(session, data) { + Ok(fab_idx) => (NocStatus::Ok, fab_idx), + Err(NocError::Status(status)) => (status, 0), + Err(NocError::Error(error)) => Err(error)?, + }; + + Self::create_nocresponse(encoder, status, fab_idx, "") } - fn handle_command_attrequest(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { + fn handle_command_attrequest( + &mut self, + session: &Session, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { cmd_enter!("AttestationRequest"); - let req = CommonReq::from_tlv(&cmd_req.data).map_err(|_| IMStatusCode::InvalidCommand)?; + let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; info!("Received Attestation Nonce:{:?}", req.str); let mut attest_challenge = [0u8; crypto::SYMM_KEY_LEN_BYTES]; - attest_challenge.copy_from_slice(cmd_req.trans.session.get_att_challenge()); - - let cmd_data = |tag: TagType, t: &mut TLVWriter| { - let mut buf: [u8; RESP_MAX] = [0; RESP_MAX]; - let mut attest_element = WriteBuf::new(&mut buf, RESP_MAX); - let _ = t.start_struct(tag); - let _ = - add_attestation_element(self.dev_att.as_ref(), req.str.0, &mut attest_element, t); - let _ = add_attestation_signature( - self.dev_att.as_ref(), - &mut attest_element, - &attest_challenge, - t, - ); - let _ = t.end_container(); - }; - let resp = ib::InvResp::cmd_new( - 0, - ID, - Commands::AttReqResp as u16, - EncodeValue::Closure(&cmd_data), - ); - let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous); - cmd_req.trans.complete(); - Ok(()) + attest_challenge.copy_from_slice(session.get_att_challenge()); + + let mut writer = encoder.with_command(RespCommands::AttReqResp as _)?; + + let mut buf: [u8; RESP_MAX] = [0; RESP_MAX]; + let mut attest_element = WriteBuf::new(&mut buf); + writer.start_struct(CmdDataWriter::TAG)?; + add_attestation_element( + self.epoch, + self.dev_att, + req.str.0, + &mut attest_element, + &mut writer, + )?; + add_attestation_signature( + self.dev_att, + &mut attest_element, + &attest_challenge, + &mut writer, + )?; + writer.end_container()?; + + writer.complete() } fn handle_command_certchainrequest( &mut self, - cmd_req: &mut CommandReq, - ) -> Result<(), IMStatusCode> { + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { cmd_enter!("CertChainRequest"); - info!("Received data: {}", cmd_req.data); - let cert_type = - get_certchainrequest_params(&cmd_req.data).map_err(|_| IMStatusCode::InvalidCommand)?; + info!("Received data: {}", data); + let cert_type = get_certchainrequest_params(data).map_err(Error::map_invalid_command)?; let mut buf: [u8; RESP_MAX] = [0; RESP_MAX]; - let len = self - .dev_att - .get_devatt_data(cert_type, &mut buf) - .map_err(|_| IMStatusCode::Failure)?; + let len = self.dev_att.get_devatt_data(cert_type, &mut buf)?; let buf = &buf[0..len]; let cmd_data = CertChainResp { cert: OctetStr::new(buf), }; - let resp = ib::InvResp::cmd_new( - 0, - ID, - Commands::CertChainResp as u16, - EncodeValue::Value(&cmd_data), - ); - let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous); - cmd_req.trans.complete(); - Ok(()) + + encoder + .with_command(RespCommands::CertChainResp as _)? + .set(&cmd_data) } - fn handle_command_csrrequest(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { + fn handle_command_csrrequest( + &mut self, + session: &mut Session, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { cmd_enter!("CSRRequest"); - let req = CommonReq::from_tlv(&cmd_req.data).map_err(|_| IMStatusCode::InvalidCommand)?; + let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; info!("Received CSR Nonce:{:?}", req.str); - if !self.failsafe.is_armed() { - return Err(IMStatusCode::UnsupportedAccess); + if !self.failsafe.borrow().is_armed() { + return Err(Error::UnsupportedAccess); } - let noc_keypair = KeyPair::new().map_err(|_| IMStatusCode::Failure)?; + let noc_keypair = KeyPair::new()?; let mut attest_challenge = [0u8; crypto::SYMM_KEY_LEN_BYTES]; - attest_challenge.copy_from_slice(cmd_req.trans.session.get_att_challenge()); - - let cmd_data = |tag: TagType, t: &mut TLVWriter| { - let mut buf: [u8; RESP_MAX] = [0; RESP_MAX]; - let mut nocsr_element = WriteBuf::new(&mut buf, RESP_MAX); - let _ = t.start_struct(tag); - let _ = add_nocsrelement(&noc_keypair, req.str.0, &mut nocsr_element, t); - let _ = add_attestation_signature( - self.dev_att.as_ref(), - &mut nocsr_element, - &attest_challenge, - t, - ); - let _ = t.end_container(); - }; - let resp = ib::InvResp::cmd_new( - 0, - ID, - Commands::CSRResp as u16, - EncodeValue::Closure(&cmd_data), - ); - - let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous); - let noc_data = Box::new(NocData::new(noc_keypair)); + attest_challenge.copy_from_slice(session.get_att_challenge()); + + let mut writer = encoder.with_command(RespCommands::CSRResp as _)?; + + let mut buf: [u8; RESP_MAX] = [0; RESP_MAX]; + let mut nocsr_element = WriteBuf::new(&mut buf); + writer.start_struct(CmdDataWriter::TAG)?; + add_nocsrelement(&noc_keypair, req.str.0, &mut nocsr_element, &mut writer)?; + add_attestation_signature( + self.dev_att, + &mut nocsr_element, + &attest_challenge, + &mut writer, + )?; + writer.end_container()?; + + writer.complete()?; + + let noc_data = NocData::new(noc_keypair); // Store this in the session data instead of cluster data, so it gets cleared // if the session goes away for some reason - cmd_req.trans.session.set_data(noc_data); - cmd_req.trans.complete(); + session.set_noc_data(noc_data); + Ok(()) } fn handle_command_addtrustedrootcert( &mut self, - cmd_req: &mut CommandReq, - ) -> Result<(), IMStatusCode> { + session: &mut Session, + data: &TLVElement, + ) -> Result<(), Error> { cmd_enter!("AddTrustedRootCert"); - if !self.failsafe.is_armed() { - return Err(IMStatusCode::UnsupportedAccess); + if !self.failsafe.borrow().is_armed() { + return Err(Error::UnsupportedAccess); } // This may happen on CASE or PASE. For PASE, the existence of NOC Data is necessary - match cmd_req.trans.session.get_session_mode() { + match session.get_session_mode() { SessionMode::Case(_) => error!("CASE: AddTrustedRootCert handling pending"), // For a CASE Session, we just return success for now, SessionMode::Pase => { - let noc_data = cmd_req - .trans - .session - .get_data::() - .ok_or(IMStatusCode::Failure)?; - - let req = - CommonReq::from_tlv(&cmd_req.data).map_err(|_| IMStatusCode::InvalidCommand)?; + let noc_data = session.get_noc_data::().ok_or(Error::NoSession)?; + + let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; info!("Received Trusted Cert:{:x?}", req.str); - noc_data.root_ca = Cert::new(req.str.0).map_err(|_| IMStatusCode::Failure)?; + noc_data.root_ca = Cert::new(req.str.0)?; } _ => (), } - cmd_req.trans.complete(); - Err(IMStatusCode::Success) + Ok(()) } } -impl ClusterType for NocCluster { - fn base(&self) -> &Cluster { - &self.base - } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base +impl<'a> Handler for NocCluster<'a> { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + NocCluster::read(self, attr, encoder) } - fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { - let cmd = cmd_req - .cmd - .path - .leaf - .map(num::FromPrimitive::from_u32) - .ok_or(IMStatusCode::UnsupportedCommand)? - .ok_or(IMStatusCode::UnsupportedCommand)?; - match cmd { - Commands::AddNOC => self.handle_command_addnoc(cmd_req), - Commands::CSRReq => self.handle_command_csrrequest(cmd_req), - Commands::AddTrustedRootCert => self.handle_command_addtrustedrootcert(cmd_req), - Commands::AttReq => self.handle_command_attrequest(cmd_req), - Commands::CertChainReq => self.handle_command_certchainrequest(cmd_req), - Commands::UpdateFabricLabel => self.handle_command_updatefablabel(cmd_req), - Commands::RemoveFabric => self.handle_command_rmfabric(cmd_req), - _ => Err(IMStatusCode::UnsupportedCommand), - } + fn invoke( + &mut self, + transaction: &mut Transaction, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + NocCluster::invoke(self, transaction, cmd, data, encoder) } +} - fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { - match num::FromPrimitive::from_u16(attr.attr_id) { - Some(Attributes::CurrentFabricIndex) => { - encoder.encode(EncodeValue::Value(&attr.fab_idx)) - } - Some(Attributes::Fabrics) => encoder.encode(EncodeValue::Closure(&|tag, tw| { - let _ = tw.start_array(tag); - let _ = self.fabric_mgr.for_each(|entry, fab_idx| { - if !attr.fab_filter || attr.fab_idx == fab_idx { - let _ = entry - .get_fabric_desc(fab_idx) - .to_tlv(tw, TagType::Anonymous); - } - }); - let _ = tw.end_container(); - })), - Some(Attributes::CommissionedFabrics) => { - let count = self.fabric_mgr.used_count() as u8; - encoder.encode(EncodeValue::Value(&count)) - } - _ => { - error!("Attribute not supported: this shouldn't happen"); - } - } +impl<'a> NonBlockingHandler for NocCluster<'a> {} + +impl<'a> ChangeNotifier<()> for NocCluster<'a> { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) } } fn add_attestation_element( + epoch: Epoch, dev_att: &dyn DevAttDataFetcher, att_nonce: &[u8], write_buf: &mut WriteBuf, @@ -505,7 +636,7 @@ fn add_attestation_element( let len = dev_att.get_devatt_data(dev_att::DataType::CertDeclaration, &mut cert_dec)?; let cert_dec = &cert_dec[0..len]; - let epoch = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as u32; + let epoch = epoch().as_secs() as u32; let mut writer = TLVWriter::new(write_buf); writer.start_struct(TagType::Anonymous)?; writer.str16(TagType::Context(1), cert_dec)?; @@ -513,8 +644,7 @@ fn add_attestation_element( writer.u32(TagType::Context(3), epoch)?; writer.end_container()?; - t.str16(TagType::Context(0), write_buf.as_borrow_slice())?; - Ok(()) + t.str16(TagType::Context(0), write_buf.as_slice()) } fn add_attestation_signature( @@ -532,7 +662,7 @@ fn add_attestation_signature( }?; attest_element.copy_from_slice(attest_challenge)?; let mut signature = [0u8; crypto::EC_SIGNATURE_LEN_BYTES]; - dac_key.sign_msg(attest_element.as_borrow_slice(), &mut signature)?; + dac_key.sign_msg(attest_element.as_slice(), &mut signature)?; resp.str8(TagType::Context(1), &signature) } @@ -550,52 +680,7 @@ fn add_nocsrelement( writer.str8(TagType::Context(2), csr_nonce)?; writer.end_container()?; - resp.str8(TagType::Context(0), write_buf.as_borrow_slice())?; - Ok(()) -} - -#[derive(ToTLV)] -struct CertChainResp<'a> { - cert: OctetStr<'a>, -} - -#[derive(ToTLV)] -struct NocResp { - status_code: u8, - fab_idx: u8, - debug_txt: String, -} - -#[derive(FromTLV)] -#[tlvargs(lifetime = "'a")] -struct AddNocReq<'a> { - noc_value: OctetStr<'a>, - icac_value: OctetStr<'a>, - ipk_value: OctetStr<'a>, - case_admin_subject: u64, - vendor_id: u16, -} - -#[derive(FromTLV)] -#[tlvargs(lifetime = "'a")] -struct CommonReq<'a> { - str: OctetStr<'a>, -} - -#[derive(FromTLV)] -#[tlvargs(lifetime = "'a")] -struct UpdateFabricLabelReq<'a> { - label: UtfStr<'a>, -} - -#[derive(FromTLV)] -struct CertChainReq { - cert_type: u8, -} - -#[derive(FromTLV)] -struct RemoveFabricReq { - fab_idx: u8, + resp.str8(TagType::Context(0), write_buf.as_slice()) } fn get_certchainrequest_params(data: &TLVElement) -> Result { diff --git a/matter/src/data_model/sdm/nw_commissioning.rs b/matter/src/data_model/sdm/nw_commissioning.rs index 753347d2..7afff7a2 100644 --- a/matter/src/data_model/sdm/nw_commissioning.rs +++ b/matter/src/data_model/sdm/nw_commissioning.rs @@ -16,38 +16,51 @@ */ use crate::{ - data_model::objects::{Cluster, ClusterType}, + data_model::objects::{ + AttrDataEncoder, AttrDetails, ChangeNotifier, Cluster, Dataver, Handler, + NonBlockingHandler, ATTRIBUTE_LIST, FEATURE_MAP, + }, error::Error, + utils::rand::Rand, }; pub const ID: u32 = 0x0031; +enum FeatureMap { + _Wifi = 0x01, + _Thread = 0x02, + Ethernet = 0x04, +} + +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID as _, + feature_map: FeatureMap::Ethernet as _, + attributes: &[FEATURE_MAP, ATTRIBUTE_LIST], + commands: &[], +}; + pub struct NwCommCluster { - base: Cluster, + data_ver: Dataver, } -impl ClusterType for NwCommCluster { - fn base(&self) -> &Cluster { - &self.base - } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base +impl NwCommCluster { + pub fn new(rand: Rand) -> Self { + Self { + data_ver: Dataver::new(rand), + } } } -enum FeatureMap { - _Wifi = 0x01, - _Thread = 0x02, - Ethernet = 0x04, +impl Handler for NwCommCluster { + fn read(&self, _attr: &AttrDetails, _encoder: AttrDataEncoder) -> Result<(), Error> { + Err(Error::AttributeNotFound) + } } -impl NwCommCluster { - pub fn new() -> Result, Error> { - let mut c = Box::new(Self { - base: Cluster::new(ID)?, - }); - // TODO: Arch-Specific - c.base.set_feature_map(FeatureMap::Ethernet as u32)?; - Ok(c) +impl NonBlockingHandler for NwCommCluster {} + +impl ChangeNotifier<()> for NwCommCluster { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) } } diff --git a/matter/src/data_model/system_model/access_control.rs b/matter/src/data_model/system_model/access_control.rs index df1297b8..3980a434 100644 --- a/matter/src/data_model/system_model/access_control.rs +++ b/matter/src/data_model/system_model/access_control.rs @@ -15,46 +15,130 @@ * limitations under the License. */ -use std::sync::Arc; +use core::cell::RefCell; +use core::convert::TryInto; -use num_derive::FromPrimitive; +use strum::{EnumDiscriminants, FromRepr}; -use crate::acl::{self, AclEntry, AclMgr}; +use crate::acl::{AclEntry, AclMgr}; use crate::data_model::objects::*; -use crate::error::*; -use crate::interaction_model::core::IMStatusCode; use crate::interaction_model::messages::ib::{attr_list_write, ListOperation}; use crate::tlv::{FromTLV, TLVElement, TagType, ToTLV}; +use crate::utils::rand::Rand; +use crate::{attribute_enum, error::*}; use log::{error, info}; pub const ID: u32 = 0x001F; -#[derive(FromPrimitive)] +#[derive(FromRepr, EnumDiscriminants)] +#[repr(u16)] pub enum Attributes { - Acl = 0, - Extension = 1, - SubjectsPerEntry = 2, - TargetsPerEntry = 3, - EntriesPerFabric = 4, + Acl(()) = 0, + Extension(()) = 1, + SubjectsPerEntry(AttrType) = 2, + TargetsPerEntry(AttrType) = 3, + EntriesPerFabric(AttrType) = 4, } -pub struct AccessControlCluster { - base: Cluster, - acl_mgr: Arc, +attribute_enum!(Attributes); + +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID, + feature_map: 0, + attributes: &[ + FEATURE_MAP, + ATTRIBUTE_LIST, + Attribute::new( + AttributesDiscriminants::Acl as u16, + Access::RWFA, + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::Extension as u16, + Access::RWFA, + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::SubjectsPerEntry as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::TargetsPerEntry as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::EntriesPerFabric as u16, + Access::RV, + Quality::FIXED, + ), + ], + commands: &[], +}; + +pub struct AccessControlCluster<'a> { + data_ver: Dataver, + acl_mgr: &'a RefCell, } -impl AccessControlCluster { - pub fn new(acl_mgr: Arc) -> Result, Error> { - let mut c = Box::new(AccessControlCluster { - base: Cluster::new(ID)?, +impl<'a> AccessControlCluster<'a> { + pub fn new(acl_mgr: &'a RefCell, rand: Rand) -> Self { + Self { + data_ver: Dataver::new(rand), acl_mgr, - }); - c.base.add_attribute(attr_acl_new())?; - c.base.add_attribute(attr_extension_new())?; - c.base.add_attribute(attr_subjects_per_entry_new())?; - c.base.add_attribute(attr_targets_per_entry_new())?; - c.base.add_attribute(attr_entries_per_fabric_new())?; - Ok(c) + } + } + + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(mut writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + match attr.attr_id.try_into()? { + Attributes::Acl(_) => { + writer.start_array(AttrDataWriter::TAG)?; + self.acl_mgr.borrow().for_each_acl(|entry| { + if !attr.fab_filter || Some(attr.fab_idx) == entry.fab_idx { + entry.to_tlv(&mut writer, TagType::Anonymous)?; + } + + Ok(()) + })?; + writer.end_container()?; + + writer.complete() + } + Attributes::Extension(_) => { + // Empty for now + writer.start_array(AttrDataWriter::TAG)?; + writer.end_container()?; + + writer.complete() + } + _ => { + error!("Attribute not yet supported: this shouldn't happen"); + Err(Error::AttributeNotFound) + } + } + } + } else { + Ok(()) + } + } + + pub fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + match attr.attr_id.try_into()? { + Attributes::Acl(_) => { + attr_list_write(attr, data.with_dataver(self.data_ver.get())?, |op, data| { + self.write_acl_attr(&op, data, attr.fab_idx) + }) + } + _ => { + error!("Attribute not yet supported: this shouldn't happen"); + Err(Error::AttributeNotFound) + } + } } /// Write the ACL Attribute @@ -66,141 +150,59 @@ impl AccessControlCluster { op: &ListOperation, data: &TLVElement, fab_idx: u8, - ) -> Result<(), IMStatusCode> { + ) -> Result<(), Error> { info!("Performing ACL operation {:?}", op); - let result = match op { + match op { ListOperation::AddItem | ListOperation::EditItem(_) => { - let mut acl_entry = - AclEntry::from_tlv(data).map_err(|_| IMStatusCode::ConstraintError)?; + let mut acl_entry = AclEntry::from_tlv(data)?; info!("ACL {:?}", acl_entry); // Overwrite the fabric index with our accessing fabric index acl_entry.fab_idx = Some(fab_idx); if let ListOperation::EditItem(index) = op { - self.acl_mgr.edit(*index as u8, fab_idx, acl_entry) + self.acl_mgr + .borrow_mut() + .edit(*index as u8, fab_idx, acl_entry) } else { - self.acl_mgr.add(acl_entry) + self.acl_mgr.borrow_mut().add(acl_entry) } } - ListOperation::DeleteItem(index) => self.acl_mgr.delete(*index as u8, fab_idx), - ListOperation::DeleteList => self.acl_mgr.delete_for_fabric(fab_idx), - }; - match result { - Ok(_) => Ok(()), - Err(Error::NoSpace) => Err(IMStatusCode::ResourceExhausted), - _ => Err(IMStatusCode::ConstraintError), + ListOperation::DeleteItem(index) => { + self.acl_mgr.borrow_mut().delete(*index as u8, fab_idx) + } + ListOperation::DeleteList => self.acl_mgr.borrow_mut().delete_for_fabric(fab_idx), } } } -impl ClusterType for AccessControlCluster { - fn base(&self) -> &Cluster { - &self.base - } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base - } - - fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { - match num::FromPrimitive::from_u16(attr.attr_id) { - Some(Attributes::Acl) => encoder.encode(EncodeValue::Closure(&|tag, tw| { - let _ = tw.start_array(tag); - let _ = self.acl_mgr.for_each_acl(|entry| { - if !attr.fab_filter || Some(attr.fab_idx) == entry.fab_idx { - let _ = entry.to_tlv(tw, TagType::Anonymous); - } - }); - let _ = tw.end_container(); - })), - Some(Attributes::Extension) => encoder.encode(EncodeValue::Closure(&|tag, tw| { - // Empty for now - let _ = tw.start_array(tag); - let _ = tw.end_container(); - })), - _ => { - error!("Attribute not yet supported: this shouldn't happen"); - } - } +impl<'a> Handler for AccessControlCluster<'a> { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + AccessControlCluster::read(self, attr, encoder) } - fn write_attribute( - &mut self, - attr: &AttrDetails, - data: &TLVElement, - ) -> Result<(), IMStatusCode> { - let result = if let Some(Attributes::Acl) = num::FromPrimitive::from_u16(attr.attr_id) { - attr_list_write(attr, data, |op, data| { - self.write_acl_attr(&op, data, attr.fab_idx) - }) - } else { - error!("Attribute not yet supported: this shouldn't happen"); - Err(IMStatusCode::NotFound) - }; - if result.is_ok() { - self.base.cluster_changed(); - } - result + fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + AccessControlCluster::write(self, attr, data) } } -fn attr_acl_new() -> Attribute { - Attribute::new( - Attributes::Acl as u16, - AttrValue::Custom, - Access::RWFA, - Quality::NONE, - ) -} - -fn attr_extension_new() -> Attribute { - Attribute::new( - Attributes::Extension as u16, - AttrValue::Custom, - Access::RWFA, - Quality::NONE, - ) -} - -fn attr_subjects_per_entry_new() -> Attribute { - Attribute::new( - Attributes::SubjectsPerEntry as u16, - AttrValue::Uint16(acl::SUBJECTS_PER_ENTRY as u16), - Access::RV, - Quality::FIXED, - ) -} +impl<'a> NonBlockingHandler for AccessControlCluster<'a> {} -fn attr_targets_per_entry_new() -> Attribute { - Attribute::new( - Attributes::TargetsPerEntry as u16, - AttrValue::Uint16(acl::TARGETS_PER_ENTRY as u16), - Access::RV, - Quality::FIXED, - ) -} - -fn attr_entries_per_fabric_new() -> Attribute { - Attribute::new( - Attributes::EntriesPerFabric as u16, - AttrValue::Uint16(acl::ENTRIES_PER_FABRIC as u16), - Access::RV, - Quality::FIXED, - ) +impl<'a> ChangeNotifier<()> for AccessControlCluster<'a> { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) + } } #[cfg(test)] mod tests { - use std::sync::Arc; + use core::cell::RefCell; use crate::{ acl::{AclEntry, AclMgr, AuthMode}, - data_model::{ - core::read::AttrReadEncoder, - objects::{AttrDetails, ClusterType, Privilege}, - }, + data_model::objects::{AttrDataEncoder, AttrDetails, Node, Privilege}, interaction_model::messages::ib::ListOperation, tlv::{get_root_node_struct, ElementType, TLVElement, TLVWriter, TagType, ToTLV}, - utils::writebuf::WriteBuf, + utils::{rand::dummy_rand, writebuf::WriteBuf}, }; use super::AccessControlCluster; @@ -209,16 +211,15 @@ mod tests { /// Add an ACL entry fn acl_cluster_add() { let mut buf: [u8; 100] = [0; 100]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); - let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); - let mut acl = AccessControlCluster::new(acl_mgr.clone()).unwrap(); + let acl_mgr = RefCell::new(AclMgr::new()); + let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand); let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); - let data = get_root_node_struct(writebuf.as_borrow_slice()).unwrap(); + let data = get_root_node_struct(writebuf.as_slice()).unwrap(); // Test, ACL has fabric index 2, but the accessing fabric is 1 // the fabric index in the TLV should be ignored and the ACL should be created with entry 1 @@ -227,8 +228,10 @@ mod tests { let verifier = AclEntry::new(1, Privilege::VIEW, AuthMode::Case); acl_mgr + .borrow() .for_each_acl(|a| { assert_eq!(*a, verifier); + Ok(()) }) .unwrap(); } @@ -237,25 +240,24 @@ mod tests { /// - The listindex used for edit should be relative to the current fabric fn acl_cluster_edit() { let mut buf: [u8; 100] = [0; 100]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); // Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order - let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); + let acl_mgr = RefCell::new(AclMgr::new()); let mut verifier = [ AclEntry::new(2, Privilege::VIEW, AuthMode::Case), AclEntry::new(1, Privilege::VIEW, AuthMode::Case), AclEntry::new(2, Privilege::ADMIN, AuthMode::Case), ]; for i in verifier { - acl_mgr.add(i).unwrap(); + acl_mgr.borrow_mut().add(i).unwrap(); } - let mut acl = AccessControlCluster::new(acl_mgr.clone()).unwrap(); + let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand); let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); - let data = get_root_node_struct(writebuf.as_borrow_slice()).unwrap(); + let data = get_root_node_struct(writebuf.as_slice()).unwrap(); // Test, Edit Fabric 2's index 1 - with accessing fabring as 2 - allow let result = acl.write_acl_attr(&ListOperation::EditItem(1), &data, 2); @@ -266,9 +268,11 @@ mod tests { // Also validate in the acl_mgr that the entries are in the right order let mut index = 0; acl_mgr + .borrow() .for_each_acl(|a| { assert_eq!(*a, verifier[index]); index += 1; + Ok(()) }) .unwrap(); } @@ -277,16 +281,16 @@ mod tests { /// - The listindex used for delete should be relative to the current fabric fn acl_cluster_delete() { // Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order - let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); + let acl_mgr = RefCell::new(AclMgr::new()); let input = [ AclEntry::new(2, Privilege::VIEW, AuthMode::Case), AclEntry::new(1, Privilege::VIEW, AuthMode::Case), AclEntry::new(2, Privilege::ADMIN, AuthMode::Case), ]; for i in input { - acl_mgr.add(i).unwrap(); + acl_mgr.borrow_mut().add(i).unwrap(); } - let mut acl = AccessControlCluster::new(acl_mgr.clone()).unwrap(); + let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand); // data is don't-care actually let data = TLVElement::new(TagType::Anonymous, ElementType::True); @@ -298,9 +302,11 @@ mod tests { // Also validate in the acl_mgr that the entries are in the right order let mut index = 0; acl_mgr + .borrow() .for_each_acl(|a| { assert_eq!(*a, verifier[index]); index += 1; + Ok(()) }) .unwrap(); } @@ -309,84 +315,126 @@ mod tests { /// - acl read with and without fabric filtering fn acl_cluster_read() { let mut buf: [u8; 100] = [0; 100]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut writebuf = WriteBuf::new(&mut buf); // Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order - let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); + let acl_mgr = RefCell::new(AclMgr::new()); let input = [ AclEntry::new(2, Privilege::VIEW, AuthMode::Case), AclEntry::new(1, Privilege::VIEW, AuthMode::Case), AclEntry::new(2, Privilege::ADMIN, AuthMode::Case), ]; for i in input { - acl_mgr.add(i).unwrap(); + acl_mgr.borrow_mut().add(i).unwrap(); } - let acl = AccessControlCluster::new(acl_mgr).unwrap(); + let acl = AccessControlCluster::new(&acl_mgr, dummy_rand); // Test 1, all 3 entries are read in the response without fabric filtering { - let mut tw = TLVWriter::new(&mut writebuf); - let mut encoder = AttrReadEncoder::new(&mut tw); - let attr_details = AttrDetails { + let attr = AttrDetails { + node: &Node { + id: 0, + endpoints: &[], + }, + endpoint_id: 0, + cluster_id: 0, attr_id: 0, list_index: None, fab_idx: 1, fab_filter: false, + dataver: None, + wildcard: false, }; - acl.read_custom_attribute(&mut encoder, &attr_details); + + let mut tw = TLVWriter::new(&mut writebuf); + let encoder = AttrDataEncoder::new(&attr, &mut tw); + + acl.read(&attr, encoder).unwrap(); assert_eq!( + // &[ + // 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, + // 4, 24, 36, 254, 2, 24, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, + // 1, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 2, 24, 24, 24, + // 24 + // ], &[ - 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, - 4, 24, 36, 254, 2, 24, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, - 1, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 2, 24, 24, 24, - 24 + 21, 53, 1, 36, 0, 0, 55, 1, 36, 2, 0, 36, 3, 0, 36, 4, 0, 24, 54, 2, 21, 36, 1, + 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 2, 24, 21, 36, 1, 1, 36, 2, 2, 54, + 3, 24, 54, 4, 24, 36, 254, 1, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, + 36, 254, 2, 24, 24, 24, 24 ], - writebuf.as_borrow_slice() + writebuf.as_slice() ); } writebuf.reset(0); // Test 2, only single entry is read in the response with fabric filtering and fabric idx 1 { - let mut tw = TLVWriter::new(&mut writebuf); - let mut encoder = AttrReadEncoder::new(&mut tw); - - let attr_details = AttrDetails { + let attr = AttrDetails { + node: &Node { + id: 0, + endpoints: &[], + }, + endpoint_id: 0, + cluster_id: 0, attr_id: 0, list_index: None, fab_idx: 1, fab_filter: true, + dataver: None, + wildcard: false, }; - acl.read_custom_attribute(&mut encoder, &attr_details); + + let mut tw = TLVWriter::new(&mut writebuf); + let encoder = AttrDataEncoder::new(&attr, &mut tw); + + acl.read(&attr, encoder).unwrap(); assert_eq!( + // &[ + // 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, + // 4, 24, 36, 254, 1, 24, 24, 24, 24 + // ], &[ - 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, - 4, 24, 36, 254, 1, 24, 24, 24, 24 + 21, 53, 1, 36, 0, 0, 55, 1, 36, 2, 0, 36, 3, 0, 36, 4, 0, 24, 54, 2, 21, 36, 1, + 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 1, 24, 24, 24, 24 ], - writebuf.as_borrow_slice() + writebuf.as_slice() ); } writebuf.reset(0); // Test 3, only single entry is read in the response with fabric filtering and fabric idx 2 { - let mut tw = TLVWriter::new(&mut writebuf); - let mut encoder = AttrReadEncoder::new(&mut tw); - - let attr_details = AttrDetails { + let attr = AttrDetails { + node: &Node { + id: 0, + endpoints: &[], + }, + endpoint_id: 0, + cluster_id: 0, attr_id: 0, list_index: None, fab_idx: 2, fab_filter: true, + dataver: None, + wildcard: false, }; - acl.read_custom_attribute(&mut encoder, &attr_details); + + let mut tw = TLVWriter::new(&mut writebuf); + let encoder = AttrDataEncoder::new(&attr, &mut tw); + + acl.read(&attr, encoder).unwrap(); assert_eq!( + // &[ + // 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, + // 4, 24, 36, 254, 2, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, + // 2, 24, 24, 24, 24 + // ], &[ - 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, - 4, 24, 36, 254, 2, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, - 2, 24, 24, 24, 24 + 21, 53, 1, 36, 0, 0, 55, 1, 36, 2, 0, 36, 3, 0, 36, 4, 0, 24, 54, 2, 21, 36, 1, + 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 2, 24, 21, 36, 1, 5, 36, 2, 2, 54, + 3, 24, 54, 4, 24, 36, 254, 2, 24, 24, 24, 24 ], - writebuf.as_borrow_slice() + writebuf.as_slice() ); } } diff --git a/matter/src/data_model/system_model/descriptor.rs b/matter/src/data_model/system_model/descriptor.rs index 4fba0fa2..2df17f57 100644 --- a/matter/src/data_model/system_model/descriptor.rs +++ b/matter/src/data_model/system_model/descriptor.rs @@ -15,18 +15,20 @@ * limitations under the License. */ -use num_derive::FromPrimitive; +use core::convert::TryInto; -use crate::data_model::core::DataModel; +use strum::FromRepr; + +use crate::attribute_enum; use crate::data_model::objects::*; -use crate::error::*; -use crate::interaction_model::messages::GenericPath; +use crate::error::Error; use crate::tlv::{TLVWriter, TagType, ToTLV}; -use log::error; +use crate::utils::rand::Rand; pub const ID: u32 = 0x001D; -#[derive(FromPrimitive)] +#[derive(FromRepr)] +#[repr(u16)] #[allow(clippy::enum_variant_names)] pub enum Attributes { DeviceTypeList = 0, @@ -35,134 +37,155 @@ pub enum Attributes { PartsList = 3, } +attribute_enum!(Attributes); + +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID as _, + feature_map: 0, + attributes: &[ + FEATURE_MAP, + ATTRIBUTE_LIST, + Attribute::new(Attributes::DeviceTypeList as u16, Access::RV, Quality::NONE), + Attribute::new(Attributes::ServerList as u16, Access::RV, Quality::NONE), + Attribute::new(Attributes::PartsList as u16, Access::RV, Quality::NONE), + Attribute::new(Attributes::ClientList as u16, Access::RV, Quality::NONE), + ], + commands: &[], +}; + pub struct DescriptorCluster { - base: Cluster, - endpoint_id: EndptId, - data_model: DataModel, + data_ver: Dataver, } impl DescriptorCluster { - pub fn new(endpoint_id: EndptId, data_model: DataModel) -> Result, Error> { - let mut c = Box::new(DescriptorCluster { - endpoint_id, - data_model, - base: Cluster::new(ID)?, - }); - let attrs = [ - Attribute::new( - Attributes::DeviceTypeList as u16, - AttrValue::Custom, - Access::RV, - Quality::NONE, - ), - Attribute::new( - Attributes::ServerList as u16, - AttrValue::Custom, - Access::RV, - Quality::NONE, - ), - Attribute::new( - Attributes::PartsList as u16, - AttrValue::Custom, - Access::RV, - Quality::NONE, - ), - Attribute::new( - Attributes::ClientList as u16, - AttrValue::Custom, - Access::RV, - Quality::NONE, - ), - ]; - c.base.add_attributes(&attrs[..])?; - Ok(c) + pub fn new(rand: Rand) -> Self { + Self { + data_ver: Dataver::new(rand), + } } - fn encode_devtype_list(&self, tag: TagType, tw: &mut TLVWriter) { - let path = GenericPath { - endpoint: Some(self.endpoint_id), - cluster: None, - leaf: None, - }; - let _ = tw.start_array(tag); - let dm = self.data_model.node.read().unwrap(); - let _ = dm.for_each_endpoint(&path, |_, e| { - let dev_type = e.get_dev_type(); - let _ = dev_type.to_tlv(tw, TagType::Anonymous); + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(mut writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + match attr.attr_id.try_into()? { + Attributes::DeviceTypeList => { + self.encode_devtype_list(attr.node, AttrDataWriter::TAG, &mut writer)?; + writer.complete() + } + Attributes::ServerList => { + self.encode_server_list( + attr.node, + attr.endpoint_id, + AttrDataWriter::TAG, + &mut writer, + )?; + writer.complete() + } + Attributes::PartsList => { + self.encode_parts_list( + attr.node, + attr.endpoint_id, + AttrDataWriter::TAG, + &mut writer, + )?; + writer.complete() + } + Attributes::ClientList => { + self.encode_client_list( + attr.node, + attr.endpoint_id, + AttrDataWriter::TAG, + &mut writer, + )?; + writer.complete() + } + } + } + } else { Ok(()) - }); - let _ = tw.end_container(); + } } - fn encode_server_list(&self, tag: TagType, tw: &mut TLVWriter) { - let path = GenericPath { - endpoint: Some(self.endpoint_id), - cluster: None, - leaf: None, - }; - let _ = tw.start_array(tag); - let dm = self.data_model.node.read().unwrap(); - let _ = dm.for_each_cluster(&path, |_current_path, c| { - let _ = tw.u32(TagType::Anonymous, c.base().id()); - Ok(()) - }); - let _ = tw.end_container(); + fn encode_devtype_list( + &self, + node: &Node, + tag: TagType, + tw: &mut TLVWriter, + ) -> Result<(), Error> { + tw.start_array(tag)?; + for endpoint in node.endpoints { + let dev_type = endpoint.device_type; + dev_type.to_tlv(tw, TagType::Anonymous)?; + } + + tw.end_container() } - fn encode_parts_list(&self, tag: TagType, tw: &mut TLVWriter) { - let path = GenericPath { - endpoint: None, - cluster: None, - leaf: None, - }; - let _ = tw.start_array(tag); - if self.endpoint_id == 0 { + fn encode_server_list( + &self, + node: &Node, + endpoint_id: u16, + tag: TagType, + tw: &mut TLVWriter, + ) -> Result<(), Error> { + tw.start_array(tag)?; + for endpoint in node.endpoints { + if endpoint.id == endpoint_id { + for cluster in endpoint.clusters { + tw.u32(TagType::Anonymous, cluster.id as _)?; + } + } + } + + tw.end_container() + } + + fn encode_parts_list( + &self, + node: &Node, + endpoint_id: u16, + tag: TagType, + tw: &mut TLVWriter, + ) -> Result<(), Error> { + tw.start_array(tag)?; + + if endpoint_id == 0 { // TODO: If endpoint is another than 0, need to figure out what to do - let dm = self.data_model.node.read().unwrap(); - let _ = dm.for_each_endpoint(&path, |current_path, _| { - if let Some(endpoint_id) = current_path.endpoint { - if endpoint_id != 0 { - let _ = tw.u16(TagType::Anonymous, endpoint_id); - } + for endpoint in node.endpoints { + if endpoint.id != 0 { + tw.u16(TagType::Anonymous, endpoint.id)?; } - Ok(()) - }); + } } - let _ = tw.end_container(); + + tw.end_container() } - fn encode_client_list(&self, tag: TagType, tw: &mut TLVWriter) { + fn encode_client_list( + &self, + _node: &Node, + _endpoint_id: u16, + tag: TagType, + tw: &mut TLVWriter, + ) -> Result<(), Error> { // No Clients supported - let _ = tw.start_array(tag); - let _ = tw.end_container(); + tw.start_array(tag)?; + tw.end_container() } } -impl ClusterType for DescriptorCluster { - fn base(&self) -> &Cluster { - &self.base - } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base +impl Handler for DescriptorCluster { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + DescriptorCluster::read(self, attr, encoder) } +} - fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { - match num::FromPrimitive::from_u16(attr.attr_id) { - Some(Attributes::DeviceTypeList) => encoder.encode(EncodeValue::Closure(&|tag, tw| { - self.encode_devtype_list(tag, tw) - })), - Some(Attributes::ServerList) => encoder.encode(EncodeValue::Closure(&|tag, tw| { - self.encode_server_list(tag, tw) - })), - Some(Attributes::PartsList) => encoder.encode(EncodeValue::Closure(&|tag, tw| { - self.encode_parts_list(tag, tw) - })), - Some(Attributes::ClientList) => encoder.encode(EncodeValue::Closure(&|tag, tw| { - self.encode_client_list(tag, tw) - })), - _ => { - error!("Attribute not supported: this shouldn't happen"); - } - } +impl NonBlockingHandler for DescriptorCluster {} + +impl ChangeNotifier<()> for DescriptorCluster { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) } } diff --git a/matter/src/error.rs b/matter/src/error.rs index 07cd681f..3a54b2c7 100644 --- a/matter/src/error.rs +++ b/matter/src/error.rs @@ -15,13 +15,14 @@ * limitations under the License. */ -use std::{ - array::TryFromSliceError, fmt, string::FromUtf8Error, sync::PoisonError, time::SystemTimeError, -}; +use alloc::string::FromUtf8Error; +use core::{array::TryFromSliceError, fmt}; use async_channel::{SendError, TryRecvError}; use log::error; +extern crate alloc; + #[derive(Debug, PartialEq, Clone, Copy)] pub enum Error { AttributeNotFound, @@ -31,6 +32,13 @@ pub enum Error { CommandNotFound, Duplicate, EndpointNotFound, + InvalidAction, + InvalidCommand, + InvalidDataType, + UnsupportedAccess, + ResourceExhausted, + Busy, + DataVersionMismatch, Crypto, TLSStack, MdnsError, @@ -71,6 +79,36 @@ pub enum Error { Utf8Fail, } +impl Error { + pub fn remap(self, matcher: F, to: Self) -> Self + where + F: FnOnce(&Self) -> bool, + { + if matcher(&self) { + to + } else { + self + } + } + + pub fn map_invalid(self, to: Self) -> Self { + self.remap(|e| matches!(e, Self::Invalid | Self::InvalidData), to) + } + + pub fn map_invalid_command(self) -> Self { + self.map_invalid(Error::InvalidCommand) + } + + pub fn map_invalid_action(self) -> Self { + self.map_invalid(Error::InvalidAction) + } + + pub fn map_invalid_data_type(self) -> Self { + self.map_invalid(Error::InvalidDataType) + } +} + +#[cfg(feature = "std")] impl From for Error { fn from(_e: std::io::Error) -> Self { // Keep things simple for now @@ -78,8 +116,9 @@ impl From for Error { } } -impl From> for Error { - fn from(_e: PoisonError) -> Self { +#[cfg(feature = "std")] +impl From> for Error { + fn from(_e: std::sync::PoisonError) -> Self { Self::RwLock } } @@ -107,8 +146,9 @@ impl From for Error { } } -impl From for Error { - fn from(_e: SystemTimeError) -> Self { +#[cfg(feature = "std")] +impl From for Error { + fn from(_e: std::time::SystemTimeError) -> Self { Self::SysTimeFail } } diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index 1715db76..42c55fdf 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -15,29 +15,33 @@ * limitations under the License. */ -use std::sync::{Arc, Mutex, MutexGuard, RwLock}; +use core::fmt::Write; use byteorder::{BigEndian, ByteOrder, LittleEndian}; use log::{error, info}; -use owning_ref::RwLockReadGuardRef; use crate::{ cert::Cert, - crypto::{self, crypto_dummy::KeyPairDummy, hkdf_sha256, CryptoKeyPair, HmacSha256, KeyPair}, + crypto::{self, hkdf_sha256, HmacSha256, KeyPair}, error::Error, group_keys::KeySet, - mdns::{self, Mdns}, - sys::{Psm, SysMdnsService}, + mdns::{MdnsMgr, ServiceMode}, + persist::Psm, tlv::{OctetStr, TLVWriter, TagType, ToTLV, UtfStr}, }; -const MAX_CERT_TLV_LEN: usize = 350; +const MAX_CERT_TLV_LEN: usize = 300; const COMPRESSED_FABRIC_ID_LEN: usize = 8; macro_rules! fb_key { - ($index:ident, $key:ident) => { - &format!("fb{}{}", $index, $key) - }; + ($index:ident, $key:ident, $buf:expr) => {{ + use core::fmt::Write; + + $buf = "".into(); + write!(&mut $buf, "fb{}{}", $index, $key).unwrap(); + + &$buf + }}; } const ST_VID: &str = "vid"; @@ -50,20 +54,6 @@ const ST_PBKEY: &str = "pubkey"; const ST_PRKEY: &str = "privkey"; #[allow(dead_code)] -pub struct Fabric { - node_id: u64, - fabric_id: u64, - vendor_id: u16, - key_pair: Box, - pub root_ca: Cert, - pub icac: Option, - pub noc: Cert, - pub ipk: KeySet, - label: String, - compressed_id: [u8; COMPRESSED_FABRIC_ID_LEN], - mdns_service: Option, -} - #[derive(ToTLV)] #[tlvargs(lifetime = "'a", start = 1)] pub struct FabricDescriptor<'a> { @@ -77,6 +67,19 @@ pub struct FabricDescriptor<'a> { pub fab_idx: Option, } +pub struct Fabric { + node_id: u64, + fabric_id: u64, + vendor_id: u16, + key_pair: KeyPair, + pub root_ca: Cert, + pub icac: Option, + pub noc: Cert, + pub ipk: KeySet, + label: heapless::String<32>, + mdns_service_name: heapless::String<33>, +} + impl Fabric { pub fn new( key_pair: KeyPair, @@ -85,56 +88,43 @@ impl Fabric { noc: Cert, ipk: &[u8], vendor_id: u16, + label: &str, ) -> Result { let node_id = noc.get_node_id()?; let fabric_id = noc.get_fabric_id()?; - let mut f = Self { - node_id, - fabric_id, - vendor_id, - key_pair: Box::new(key_pair), - root_ca, - icac, - noc, - ipk: KeySet::default(), - compressed_id: [0; COMPRESSED_FABRIC_ID_LEN], - label: "".into(), - mdns_service: None, - }; - Fabric::get_compressed_id(f.root_ca.get_pubkey(), fabric_id, &mut f.compressed_id)?; - f.ipk = KeySet::new(ipk, &f.compressed_id)?; + let mut compressed_id = [0_u8; COMPRESSED_FABRIC_ID_LEN]; + + Fabric::get_compressed_id(root_ca.get_pubkey(), fabric_id, &mut compressed_id)?; + let ipk = KeySet::new(ipk, &compressed_id)?; - let mut mdns_service_name = String::with_capacity(33); - for c in f.compressed_id { - mdns_service_name.push_str(&format!("{:02X}", c)); + let mut mdns_service_name = heapless::String::<33>::new(); + for c in compressed_id { + let mut hex = heapless::String::<4>::new(); + write!(&mut hex, "{:02X}", c).unwrap(); + mdns_service_name.push_str(&hex).unwrap(); } - mdns_service_name.push('-'); + mdns_service_name.push('-').unwrap(); let mut node_id_be: [u8; 8] = [0; 8]; BigEndian::write_u64(&mut node_id_be, node_id); for c in node_id_be { - mdns_service_name.push_str(&format!("{:02X}", c)); + let mut hex = heapless::String::<4>::new(); + write!(&mut hex, "{:02X}", c).unwrap(); + mdns_service_name.push_str(&hex).unwrap(); } info!("MDNS Service Name: {}", mdns_service_name); - f.mdns_service = Some( - Mdns::get()?.publish_service(&mdns_service_name, mdns::ServiceMode::Commissioned)?, - ); - Ok(f) - } - pub fn dummy() -> Result { Ok(Self { - node_id: 0, - fabric_id: 0, - vendor_id: 0, - key_pair: Box::new(KeyPairDummy::new()?), - root_ca: Cert::default(), - icac: Some(Cert::default()), - noc: Cert::default(), - ipk: KeySet::default(), - label: "".into(), - compressed_id: [0; COMPRESSED_FABRIC_ID_LEN], - mdns_service: None, + node_id, + fabric_id, + vendor_id, + key_pair, + root_ca, + icac, + noc, + ipk, + label: label.into(), + mdns_service_name, }) } @@ -195,164 +185,362 @@ impl Fabric { } } - fn rm_store(&self, index: usize, psm: &MutexGuard) { - psm.rm(fb_key!(index, ST_RCA)); - psm.rm(fb_key!(index, ST_ICA)); - psm.rm(fb_key!(index, ST_NOC)); - psm.rm(fb_key!(index, ST_IPK)); - psm.rm(fb_key!(index, ST_LBL)); - psm.rm(fb_key!(index, ST_PBKEY)); - psm.rm(fb_key!(index, ST_PRKEY)); - psm.rm(fb_key!(index, ST_VID)); - } + fn store(&self, index: usize, mut psm: T) -> Result<(), Error> + where + T: Psm, + { + let mut _kb = heapless::String::<32>::new(); - fn store(&self, index: usize, psm: &MutexGuard) -> Result<(), Error> { - let mut key = [0u8; MAX_CERT_TLV_LEN]; - let len = self.root_ca.as_tlv(&mut key)?; - psm.set_kv_slice(fb_key!(index, ST_RCA), &key[..len])?; + let mut buf = [0u8; MAX_CERT_TLV_LEN]; + let len = self.root_ca.as_tlv(&mut buf)?; + psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &buf[..len])?; let len = if let Some(icac) = &self.icac { - icac.as_tlv(&mut key)? + icac.as_tlv(&mut buf)? } else { 0 }; - psm.set_kv_slice(fb_key!(index, ST_ICA), &key[..len])?; + psm.set_kv_slice(fb_key!(index, ST_ICA, _kb), &buf[..len])?; - let len = self.noc.as_tlv(&mut key)?; - psm.set_kv_slice(fb_key!(index, ST_NOC), &key[..len])?; - psm.set_kv_slice(fb_key!(index, ST_IPK), self.ipk.epoch_key())?; - psm.set_kv_slice(fb_key!(index, ST_LBL), self.label.as_bytes())?; + let len = self.noc.as_tlv(&mut buf)?; + psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &buf[..len])?; + psm.set_kv_slice(fb_key!(index, ST_IPK, _kb), self.ipk.epoch_key())?; + psm.set_kv_slice(fb_key!(index, ST_LBL, _kb), self.label.as_bytes())?; - let mut key = [0_u8; crypto::EC_POINT_LEN_BYTES]; - let len = self.key_pair.get_public_key(&mut key)?; - let key = &key[..len]; - psm.set_kv_slice(fb_key!(index, ST_PBKEY), key)?; + let mut buf = [0_u8; crypto::EC_POINT_LEN_BYTES]; + let len = self.key_pair.get_public_key(&mut buf)?; + let key = &buf[..len]; + psm.set_kv_slice(fb_key!(index, ST_PBKEY, _kb), key)?; - let mut key = [0_u8; crypto::BIGNUM_LEN_BYTES]; - let len = self.key_pair.get_private_key(&mut key)?; - let key = &key[..len]; - psm.set_kv_slice(fb_key!(index, ST_PRKEY), key)?; + let mut buf = [0_u8; crypto::BIGNUM_LEN_BYTES]; + let len = self.key_pair.get_private_key(&mut buf)?; + let key = &buf[..len]; + psm.set_kv_slice(fb_key!(index, ST_PRKEY, _kb), key)?; - psm.set_kv_u64(fb_key!(index, ST_VID), self.vendor_id.into())?; + psm.set_kv_u64(fb_key!(index, ST_VID, _kb), self.vendor_id.into())?; Ok(()) } - fn load(index: usize, psm: &MutexGuard) -> Result { - let mut root_ca = Vec::new(); - psm.get_kv_slice(fb_key!(index, ST_RCA), &mut root_ca)?; - let root_ca = Cert::new(root_ca.as_slice())?; + fn load(index: usize, psm: T) -> Result + where + T: Psm, + { + let mut _kb = heapless::String::<32>::new(); + + let mut buf = [0u8; MAX_CERT_TLV_LEN]; + let root_ca = psm.get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf)?; + let root_ca = Cert::new(root_ca)?; - let mut icac = Vec::new(); - psm.get_kv_slice(fb_key!(index, ST_ICA), &mut icac)?; + let icac = psm.get_kv_slice(fb_key!(index, ST_ICA, _kb), &mut buf)?; let icac = if !icac.is_empty() { - Some(Cert::new(icac.as_slice())?) + Some(Cert::new(icac)?) } else { None }; - let mut noc = Vec::new(); - psm.get_kv_slice(fb_key!(index, ST_NOC), &mut noc)?; - let noc = Cert::new(noc.as_slice())?; + let noc = psm.get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf)?; + let noc = Cert::new(noc)?; - let mut ipk = Vec::new(); - psm.get_kv_slice(fb_key!(index, ST_IPK), &mut ipk)?; + let label = psm.get_kv_slice(fb_key!(index, ST_LBL, _kb), &mut buf)?; + let label: heapless::String<32> = core::str::from_utf8(label) + .map_err(|_| { + error!("Couldn't read label"); + Error::Invalid + })? + .into(); - let mut label = Vec::new(); - psm.get_kv_slice(fb_key!(index, ST_LBL), &mut label)?; - let label = String::from_utf8(label).map_err(|_| { - error!("Couldn't read label"); - Error::Invalid - })?; + let ipk = psm.get_kv_slice(fb_key!(index, ST_IPK, _kb), &mut buf)?; - let mut pub_key = Vec::new(); - psm.get_kv_slice(fb_key!(index, ST_PBKEY), &mut pub_key)?; - let mut priv_key = Vec::new(); - psm.get_kv_slice(fb_key!(index, ST_PRKEY), &mut priv_key)?; - let keypair = KeyPair::new_from_components(pub_key.as_slice(), priv_key.as_slice())?; + let mut buf = [0_u8; crypto::EC_POINT_LEN_BYTES]; + let pub_key = psm.get_kv_slice(fb_key!(index, ST_PBKEY, _kb), &mut buf)?; - let mut vendor_id = 0; - psm.get_kv_u64(fb_key!(index, ST_VID), &mut vendor_id)?; + let mut buf = [0_u8; crypto::BIGNUM_LEN_BYTES]; + let priv_key = psm.get_kv_slice(fb_key!(index, ST_PRKEY, _kb), &mut buf)?; + let keypair = KeyPair::new_from_components(pub_key, priv_key)?; - let f = Fabric::new( - keypair, - root_ca, - icac, - noc, - ipk.as_slice(), - vendor_id as u16, - ); - f.map(|mut f| { - f.label = label; - f - }) + let vendor_id = psm.get_kv_u64(fb_key!(index, ST_VID, _kb))?; + + Fabric::new(keypair, root_ca, icac, noc, ipk, vendor_id as u16, &label) + } + + fn remove(index: usize, mut psm: T) -> Result<(), Error> + where + T: Psm, + { + let mut _kb = heapless::String::<32>::new(); + + psm.remove(fb_key!(index, ST_RCA, _kb))?; + psm.remove(fb_key!(index, ST_ICA, _kb))?; + + psm.remove(fb_key!(index, ST_NOC, _kb))?; + + psm.remove(fb_key!(index, ST_LBL, _kb))?; + + psm.remove(fb_key!(index, ST_IPK, _kb))?; + + psm.remove(fb_key!(index, ST_PBKEY, _kb))?; + psm.remove(fb_key!(index, ST_PRKEY, _kb))?; + + psm.remove(fb_key!(index, ST_VID, _kb))?; + + Ok(()) + } + + #[cfg(feature = "nightly")] + async fn store_async(&self, index: usize, mut psm: T) -> Result<(), Error> + where + T: crate::persist::asynch::AsyncPsm, + { + let mut _kb = heapless::String::<32>::new(); + + let mut buf = [0u8; MAX_CERT_TLV_LEN]; + let len = self.root_ca.as_tlv(&mut buf)?; + psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &buf[..len]) + .await?; + + let len = if let Some(icac) = &self.icac { + icac.as_tlv(&mut buf)? + } else { + 0 + }; + psm.set_kv_slice(fb_key!(index, ST_ICA, _kb), &buf[..len]) + .await?; + + let len = self.noc.as_tlv(&mut buf)?; + psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &buf[..len]) + .await?; + psm.set_kv_slice(fb_key!(index, ST_IPK, _kb), self.ipk.epoch_key()) + .await?; + psm.set_kv_slice(fb_key!(index, ST_LBL, _kb), self.label.as_bytes()) + .await?; + + let mut buf = [0_u8; crypto::EC_POINT_LEN_BYTES]; + let len = self.key_pair.get_public_key(&mut buf)?; + let key = &buf[..len]; + psm.set_kv_slice(fb_key!(index, ST_PBKEY, _kb), key).await?; + + let mut buf = [0_u8; crypto::BIGNUM_LEN_BYTES]; + let len = self.key_pair.get_private_key(&mut buf)?; + let key = &buf[..len]; + psm.set_kv_slice(fb_key!(index, ST_PRKEY, _kb), key).await?; + + psm.set_kv_u64(fb_key!(index, ST_VID, _kb), self.vendor_id.into()) + .await?; + Ok(()) + } + + #[cfg(feature = "nightly")] + async fn load_async(index: usize, psm: T) -> Result + where + T: crate::persist::asynch::AsyncPsm, + { + let mut _kb = heapless::String::<32>::new(); + + let mut buf = [0u8; MAX_CERT_TLV_LEN]; + let root_ca = psm + .get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf) + .await?; + let root_ca = Cert::new(root_ca)?; + + let icac = psm + .get_kv_slice(fb_key!(index, ST_ICA, _kb), &mut buf) + .await?; + let icac = if !icac.is_empty() { + Some(Cert::new(icac)?) + } else { + None + }; + + let noc = psm + .get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf) + .await?; + let noc = Cert::new(noc)?; + + let label = psm + .get_kv_slice(fb_key!(index, ST_LBL, _kb), &mut buf) + .await?; + let label: heapless::String<32> = core::str::from_utf8(label) + .map_err(|_| { + error!("Couldn't read label"); + Error::Invalid + })? + .into(); + + let ipk = psm + .get_kv_slice(fb_key!(index, ST_IPK, _kb), &mut buf) + .await?; + + let mut buf = [0_u8; crypto::EC_POINT_LEN_BYTES]; + let pub_key = psm + .get_kv_slice(fb_key!(index, ST_PBKEY, _kb), &mut buf) + .await?; + + let mut buf = [0_u8; crypto::BIGNUM_LEN_BYTES]; + let priv_key = psm + .get_kv_slice(fb_key!(index, ST_PRKEY, _kb), &mut buf) + .await?; + let keypair = KeyPair::new_from_components(pub_key, priv_key)?; + + let vendor_id = psm.get_kv_u64(fb_key!(index, ST_VID, _kb)).await?; + + Fabric::new(keypair, root_ca, icac, noc, ipk, vendor_id as u16, &label) + } + + #[cfg(feature = "nightly")] + async fn remove_async(index: usize, mut psm: T) -> Result<(), Error> + where + T: crate::persist::asynch::AsyncPsm, + { + let mut _kb = heapless::String::<32>::new(); + + psm.remove(fb_key!(index, ST_RCA, _kb)).await?; + psm.remove(fb_key!(index, ST_ICA, _kb)).await?; + + psm.remove(fb_key!(index, ST_NOC, _kb)).await?; + + psm.remove(fb_key!(index, ST_LBL, _kb)).await?; + + psm.remove(fb_key!(index, ST_IPK, _kb)).await?; + + psm.remove(fb_key!(index, ST_PBKEY, _kb)).await?; + psm.remove(fb_key!(index, ST_PRKEY, _kb)).await?; + + psm.remove(fb_key!(index, ST_VID, _kb)).await?; + + Ok(()) } } pub const MAX_SUPPORTED_FABRICS: usize = 3; -#[derive(Default)] -pub struct FabricMgrInner { - // The outside world expects Fabric Index to be one more than the actual one - // since 0 is not allowed. Need to handle this cleanly somehow - pub fabrics: [Option; MAX_SUPPORTED_FABRICS], -} pub struct FabricMgr { - inner: RwLock, - psm: Arc>, + // The outside world expects Fabric Index to be one more than the actual one + // since 0 is not allowed. Need to handle this cleanly somehow + fabrics: [Option; MAX_SUPPORTED_FABRICS], + changed: bool, } impl FabricMgr { - pub fn new() -> Result { - let dummy_fabric = Fabric::dummy()?; - let mut mgr = FabricMgrInner::default(); - mgr.fabrics[0] = Some(dummy_fabric); - let mut fm = Self { - inner: RwLock::new(mgr), - psm: Psm::get()?, - }; - fm.load()?; - Ok(fm) + pub const fn new() -> Self { + const INIT: Option = None; + + Self { + fabrics: [INIT; MAX_SUPPORTED_FABRICS], + changed: false, + } } - fn store(&self, index: usize, fabric: &Fabric) -> Result<(), Error> { - let psm = self.psm.lock().unwrap(); - fabric.store(index, &psm) + pub fn store(&mut self, mut psm: T) -> Result<(), Error> + where + T: Psm, + { + if self.changed { + for i in 1..MAX_SUPPORTED_FABRICS { + if let Some(fabric) = self.fabrics[i].as_mut() { + info!("Storing fabric at index {}", i); + fabric.store(i, &mut psm)?; + } else { + let _ = Fabric::remove(i, &mut psm); + } + } + + self.changed = false; + } + + Ok(()) } - fn load(&mut self) -> Result<(), Error> { - let mut mgr = self.inner.write()?; - let psm = self.psm.lock().unwrap(); - for i in 0..MAX_SUPPORTED_FABRICS { - let result = Fabric::load(i, &psm); + pub fn load(&mut self, mut psm: T, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> + where + T: Psm, + { + for i in 1..MAX_SUPPORTED_FABRICS { + let result = Fabric::load(i, &mut psm); if let Ok(fabric) = result { info!("Adding new fabric at index {}", i); - mgr.fabrics[i] = Some(fabric); + self.fabrics[i] = Some(fabric); + mdns_mgr.publish_service( + &self.fabrics[i].as_ref().unwrap().mdns_service_name, + ServiceMode::Commissioned, + )?; + } else { + self.fabrics[i] = None; } } + + self.changed = false; + Ok(()) } - pub fn add(&self, f: Fabric) -> Result { - let mut mgr = self.inner.write()?; - let index = mgr + #[cfg(feature = "nightly")] + pub async fn store_async(&mut self, mut psm: T) -> Result<(), Error> + where + T: crate::persist::asynch::AsyncPsm, + { + if self.changed { + for i in 1..MAX_SUPPORTED_FABRICS { + if let Some(fabric) = self.fabrics[i].as_mut() { + info!("Storing fabric at index {}", i); + fabric.store_async(i, &mut psm).await?; + } else { + let _ = Fabric::remove_async(i, &mut psm).await; + } + } + + self.changed = false; + } + + Ok(()) + } + + #[cfg(feature = "nightly")] + pub async fn load_async(&mut self, mut psm: T, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> + where + T: crate::persist::asynch::AsyncPsm, + { + for i in 1..MAX_SUPPORTED_FABRICS { + let result = Fabric::load_async(i, &mut psm).await; + if let Ok(fabric) = result { + info!("Adding new fabric at index {}", i); + self.fabrics[i] = Some(fabric); + mdns_mgr.publish_service( + &self.fabrics[i].as_ref().unwrap().mdns_service_name, + ServiceMode::Commissioned, + )?; + } else { + self.fabrics[i] = None; + } + } + + self.changed = false; + + Ok(()) + } + + pub fn add(&mut self, f: Fabric, mdns_mgr: &mut MdnsMgr) -> Result { + let index = self .fabrics .iter() + .skip(1) .position(|f| f.is_none()) .ok_or(Error::NoSpace)?; - self.store(index, &f)?; + self.fabrics[index] = Some(f); + mdns_mgr.publish_service( + &self.fabrics[index].as_ref().unwrap().mdns_service_name, + ServiceMode::Commissioned, + )?; + + self.changed = true; - mgr.fabrics[index] = Some(f); Ok(index as u8) } - pub fn remove(&self, fab_idx: u8) -> Result<(), Error> { - let fab_idx = fab_idx as usize; - let mut mgr = self.inner.write().unwrap(); - let psm = self.psm.lock().unwrap(); - if let Some(f) = &mgr.fabrics[fab_idx] { - f.rm_store(fab_idx, &psm); - mgr.fabrics[fab_idx] = None; + pub fn remove(&mut self, fab_idx: u8, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> { + if let Some(f) = self.fabrics[fab_idx as usize].take() { + mdns_mgr.unpublish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; + self.changed = true; Ok(()) } else { Err(Error::NotFound) @@ -360,9 +548,8 @@ impl FabricMgr { } pub fn match_dest_id(&self, random: &[u8], target: &[u8]) -> Result { - let mgr = self.inner.read()?; - for i in 0..MAX_SUPPORTED_FABRICS { - if let Some(fabric) = &mgr.fabrics[i] { + for i in 1..MAX_SUPPORTED_FABRICS { + if let Some(fabric) = &self.fabrics[i] { if fabric.match_dest_id(random, target).is_ok() { return Ok(i); } @@ -371,17 +558,13 @@ impl FabricMgr { Err(Error::NotFound) } - pub fn get_fabric<'ret, 'me: 'ret>( - &'me self, - idx: usize, - ) -> Result>, Error> { - Ok(RwLockReadGuardRef::new(self.inner.read()?).map(|fm| &fm.fabrics[idx])) + pub fn get_fabric(&self, idx: usize) -> Result, Error> { + Ok(self.fabrics[idx].as_ref()) } pub fn is_empty(&self) -> bool { - let mgr = self.inner.read().unwrap(); for i in 1..MAX_SUPPORTED_FABRICS { - if mgr.fabrics[i].is_some() { + if self.fabrics[i].is_some() { return false; } } @@ -389,10 +572,9 @@ impl FabricMgr { } pub fn used_count(&self) -> usize { - let mgr = self.inner.read().unwrap(); let mut count = 0; for i in 1..MAX_SUPPORTED_FABRICS { - if mgr.fabrics[i].is_some() { + if self.fabrics[i].is_some() { count += 1; } } @@ -402,37 +584,30 @@ impl FabricMgr { // Parameters to T are the Fabric and its Fabric Index pub fn for_each(&self, mut f: T) -> Result<(), Error> where - T: FnMut(&Fabric, u8), + T: FnMut(&Fabric, u8) -> Result<(), Error>, { - let mgr = self.inner.read().unwrap(); for i in 1..MAX_SUPPORTED_FABRICS { - if let Some(fabric) = &mgr.fabrics[i] { - f(fabric, i as u8) + if let Some(fabric) = &self.fabrics[i] { + f(fabric, i as u8)?; } } Ok(()) } - pub fn set_label(&self, index: u8, label: String) -> Result<(), Error> { + pub fn set_label(&mut self, index: u8, label: &str) -> Result<(), Error> { let index = index as usize; - let mut mgr = self.inner.write()?; if !label.is_empty() { for i in 1..MAX_SUPPORTED_FABRICS { - if let Some(fabric) = &mgr.fabrics[i] { + if let Some(fabric) = &self.fabrics[i] { if fabric.label == label { return Err(Error::Invalid); } } } } - if let Some(fabric) = &mut mgr.fabrics[index] { - let old = fabric.label.clone(); - fabric.label = label; - let psm = self.psm.lock().unwrap(); - if fabric.store(index, &psm).is_err() { - fabric.label = old; - return Err(Error::StdIoError); - } + if let Some(fabric) = &mut self.fabrics[index] { + fabric.label = label.into(); + self.changed = true; } Ok(()) } diff --git a/matter/src/group_keys.rs b/matter/src/group_keys.rs index 73c40e5b..c4dfaafa 100644 --- a/matter/src/group_keys.rs +++ b/matter/src/group_keys.rs @@ -15,10 +15,13 @@ * limitations under the License. */ -use std::sync::{Arc, Mutex, Once}; +use alloc::sync::Arc; +use std::sync::{Mutex, Once}; use crate::{crypto, error::Error}; +extern crate alloc; + // This is just makeshift implementation for now, not used anywhere pub struct GroupKeys {} diff --git a/matter/src/interaction_model/command.rs b/matter/src/interaction_model/command.rs deleted file mode 100644 index 323c0931..00000000 --- a/matter/src/interaction_model/command.rs +++ /dev/null @@ -1,88 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use super::core::IMStatusCode; -use super::core::OpCode; -use super::messages::ib; -use super::messages::msg; -use super::messages::msg::InvReq; -use super::InteractionModel; -use super::Transaction; -use crate::{ - error::*, - tlv::{get_root_node_struct, print_tlv_list, FromTLV, TLVElement, TLVWriter, TagType}, - transport::{packet::Packet, proto_demux::ResponseRequired}, -}; -use log::error; - -#[macro_export] -macro_rules! cmd_enter { - ($e:expr) => {{ - use colored::Colorize; - info! {"{} {}", "Handling Command".cyan(), $e.cyan()} - }}; -} - -pub struct CommandReq<'a, 'b, 'c, 'd, 'e> { - pub cmd: ib::CmdPath, - pub data: TLVElement<'a>, - pub resp: &'a mut TLVWriter<'b, 'c>, - pub trans: &'a mut Transaction<'d, 'e>, -} - -impl InteractionModel { - pub fn handle_invoke_req( - &mut self, - trans: &mut Transaction, - rx_buf: &[u8], - proto_tx: &mut Packet, - ) -> Result { - if InteractionModel::req_timeout_handled(trans, proto_tx)? { - return Ok(ResponseRequired::Yes); - } - - proto_tx.set_proto_opcode(OpCode::InvokeResponse as u8); - let mut tw = TLVWriter::new(proto_tx.get_writebuf()?); - let root = get_root_node_struct(rx_buf)?; - let inv_req = InvReq::from_tlv(&root)?; - - let timed_tx = trans.get_timeout().map(|_| true); - let timed_request = inv_req.timed_request.filter(|a| *a); - // Either both should be None, or both should be Some(true) - if timed_tx != timed_request { - InteractionModel::create_status_response(proto_tx, IMStatusCode::TimedRequestMisMatch)?; - return Ok(ResponseRequired::Yes); - } - - tw.start_struct(TagType::Anonymous)?; - // Suppress Response -> TODO: Need to revisit this for cases where we send a command back - tw.bool( - TagType::Context(msg::InvRespTag::SupressResponse as u8), - false, - )?; - - self.consumer - .consume_invoke_cmd(&inv_req, trans, &mut tw) - .map_err(|e| { - error!("Error in handling command: {:?}", e); - print_tlv_list(rx_buf); - e - })?; - tw.end_container()?; - Ok(ResponseRequired::Yes) - } -} diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index 1d548eb0..8d8b4fb4 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -15,34 +15,88 @@ * limitations under the License. */ -use std::time::{Duration, SystemTime}; +use core::time::Duration; use crate::{ + data_model::core::DataHandler, error::*, - interaction_model::messages::msg::StatusResp, - tlv::{self, get_root_node_struct, FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, - transport::{ - exchange::Exchange, - packet::Packet, - proto_demux::{self, ProtoCtx, ResponseRequired}, - session::SessionHandle, - }, + tlv::{get_root_node_struct, print_tlv_list, FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, + transport::{exchange::ExchangeCtx, packet::Packet, proto_ctx::ProtoCtx, session::Session}, }; use colored::Colorize; use log::{error, info}; use num; use num_derive::FromPrimitive; -use super::InteractionModel; -use super::Transaction; -use super::TransactionState; -use super::{messages::msg::TimedReq, InteractionConsumer}; +use super::messages::msg::{self, InvReq, ReadReq, StatusResp, TimedReq, WriteReq}; -/* Handle messages related to the Interation Model - */ +#[macro_export] +macro_rules! cmd_enter { + ($e:expr) => {{ + use colored::Colorize; + info! {"{} {}", "Handling Command".cyan(), $e.cyan()} + }}; +} -/* Interaction Model ID as per the Matter Spec */ -const PROTO_ID_INTERACTION_MODEL: usize = 0x01; +#[derive(FromPrimitive, Debug, Clone, Copy, PartialEq)] +pub enum IMStatusCode { + Success = 0, + Failure = 1, + InvalidSubscription = 0x7D, + UnsupportedAccess = 0x7E, + UnsupportedEndpoint = 0x7F, + InvalidAction = 0x80, + UnsupportedCommand = 0x81, + InvalidCommand = 0x85, + UnsupportedAttribute = 0x86, + ConstraintError = 0x87, + UnsupportedWrite = 0x88, + ResourceExhausted = 0x89, + NotFound = 0x8b, + UnreportableAttribute = 0x8c, + InvalidDataType = 0x8d, + UnsupportedRead = 0x8f, + DataVersionMismatch = 0x92, + Timeout = 0x94, + Busy = 0x9c, + UnsupportedCluster = 0xc3, + NoUpstreamSubscription = 0xc5, + NeedsTimedInteraction = 0xc6, + UnsupportedEvent = 0xc7, + PathsExhausted = 0xc8, + TimedRequestMisMatch = 0xc9, + FailSafeRequired = 0xca, +} + +impl From for IMStatusCode { + fn from(e: Error) -> Self { + match e { + Error::EndpointNotFound => IMStatusCode::UnsupportedEndpoint, + Error::ClusterNotFound => IMStatusCode::UnsupportedCluster, + Error::AttributeNotFound => IMStatusCode::UnsupportedAttribute, + Error::CommandNotFound => IMStatusCode::UnsupportedCommand, + Error::InvalidAction => IMStatusCode::InvalidAction, + Error::InvalidCommand => IMStatusCode::InvalidCommand, + Error::UnsupportedAccess => IMStatusCode::UnsupportedAccess, + Error::Busy => IMStatusCode::Busy, + Error::DataVersionMismatch => IMStatusCode::DataVersionMismatch, + Error::ResourceExhausted => IMStatusCode::ResourceExhausted, + _ => IMStatusCode::Failure, + } + } +} + +impl FromTLV<'_> for IMStatusCode { + fn from_tlv(t: &TLVElement) -> Result { + num::FromPrimitive::from_u16(t.u16()?).ok_or(Error::Invalid) + } +} + +impl ToTLV for IMStatusCode { + fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { + tw.u16(tag_type, *self as u16) + } +} #[derive(FromPrimitive, Debug, Copy, Clone, PartialEq)] pub enum OpCode { @@ -59,15 +113,33 @@ pub enum OpCode { TimedRequest = 10, } +#[derive(PartialEq)] +pub enum TransactionState { + Ongoing, + Complete, + Terminate, +} +pub struct Transaction<'a, 'b> { + state: TransactionState, + ctx: &'a mut ExchangeCtx<'b>, +} + impl<'a, 'b> Transaction<'a, 'b> { - pub fn new(session: &'a mut SessionHandle<'b>, exch: &'a mut Exchange) -> Self { + pub fn new(ctx: &'a mut ExchangeCtx<'b>) -> Self { Self { state: TransactionState::Ongoing, - session, - exch, + ctx, } } + pub fn session(&self) -> &Session { + self.ctx.sess.session() + } + + pub fn session_mut(&mut self) -> &mut Session { + self.ctx.sess.session_mut() + } + /// Terminates the transaction, no communication (even ACKs) happens hence forth pub fn terminate(&mut self) { self.state = TransactionState::Terminate @@ -76,7 +148,6 @@ impl<'a, 'b> Transaction<'a, 'b> { pub fn is_terminate(&self) -> bool { self.state == TransactionState::Terminate } - /// Marks the transaction as completed from the application's perspective pub fn complete(&mut self) { self.state = TransactionState::Complete @@ -87,17 +158,20 @@ impl<'a, 'b> Transaction<'a, 'b> { } pub fn set_timeout(&mut self, timeout: u64) { - self.exch - .set_data_time(SystemTime::now().checked_add(Duration::from_millis(timeout))); + let now = (self.ctx.epoch)(); + + self.ctx + .exch + .set_data_time(now.checked_add(Duration::from_millis(timeout))); } - pub fn get_timeout(&mut self) -> Option { - self.exch.get_data_time() + pub fn get_timeout(&mut self) -> Option { + self.ctx.exch.get_data_time() } pub fn has_timed_out(&self) -> bool { - if let Some(timeout) = self.exch.get_data_time() { - if SystemTime::now() > timeout { + if let Some(timeout) = self.ctx.exch.get_data_time() { + if (self.ctx.epoch)() > timeout { return true; } } @@ -105,174 +179,336 @@ impl<'a, 'b> Transaction<'a, 'b> { } } -impl InteractionModel { - pub fn new(consumer: Box) -> InteractionModel { - InteractionModel { consumer } - } +/* Interaction Model ID as per the Matter Spec */ +const PROTO_ID_INTERACTION_MODEL: usize = 0x01; - pub fn handle_subscribe_req( - &mut self, - trans: &mut Transaction, - rx_buf: &[u8], - proto_tx: &mut Packet, - ) -> Result { - let mut tw = TLVWriter::new(proto_tx.get_writebuf()?); - let (opcode, resp) = self.consumer.consume_subscribe(rx_buf, trans, &mut tw)?; - proto_tx.set_proto_opcode(opcode as u8); - Ok(resp) - } +pub enum Interaction<'a> { + Read(ReadReq<'a>), + Write(WriteReq<'a>), + Invoke(InvReq<'a>), + Timed(TimedReq), +} - pub fn handle_status_resp( - &mut self, - trans: &mut Transaction, - rx_buf: &[u8], - proto_tx: &mut Packet, - ) -> Result { - let mut tw = TLVWriter::new(proto_tx.get_writebuf()?); - let root = get_root_node_struct(rx_buf)?; - let req = StatusResp::from_tlv(&root)?; - let (opcode, resp) = self.consumer.consume_status_report(&req, trans, &mut tw)?; - proto_tx.set_proto_opcode(opcode as u8); - Ok(resp) +impl<'a> Interaction<'a> { + pub fn new(rx: &'a Packet) -> Result { + let opcode: OpCode = + num::FromPrimitive::from_u8(rx.get_proto_opcode()).ok_or(Error::Invalid)?; + + let rx_data = rx.as_slice(); + + info!("{} {:?}", "Received command".cyan(), opcode); + print_tlv_list(rx_data); + + match opcode { + OpCode::ReadRequest => Ok(Self::Read(ReadReq::from_tlv(&get_root_node_struct( + rx_data, + )?)?)), + OpCode::WriteRequest => Ok(Self::Write(WriteReq::from_tlv(&get_root_node_struct( + rx_data, + )?)?)), + OpCode::InvokeRequest => Ok(Self::Invoke(InvReq::from_tlv(&get_root_node_struct( + rx_data, + )?)?)), + OpCode::TimedRequest => Ok(Self::Timed(TimedReq::from_tlv(&get_root_node_struct( + rx_data, + )?)?)), + // TODO + // OpCode::SubscribeRequest => self.handle_subscribe_req(&mut trans, buf, &mut ctx.tx)?, + // OpCode::StatusResponse => self.handle_status_resp(&mut trans, buf, &mut ctx.tx)?, + _ => { + error!("Opcode Not Handled: {:?}", opcode); + Err(Error::InvalidOpcode) + } + } } - pub fn handle_timed_req( - &mut self, - trans: &mut Transaction, - rx_buf: &[u8], - proto_tx: &mut Packet, - ) -> Result { - proto_tx.set_proto_opcode(OpCode::StatusResponse as u8); - - let root = get_root_node_struct(rx_buf)?; - let req = TimedReq::from_tlv(&root)?; - trans.set_timeout(req.timeout.into()); - - let status = StatusResp { - status: IMStatusCode::Success, + pub fn initiate_tx( + &self, + tx: &mut Packet, + transaction: &mut Transaction, + ) -> Result { + let reply = match self { + Self::Read(request) => { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::ReportData as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + tw.start_struct(TagType::Anonymous)?; + + if request.attr_requests.is_some() { + tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; + } + + false + } + Interaction::Write(_) => { + if transaction.has_timed_out() { + Self::create_status_response(tx, IMStatusCode::Timeout)?; + + transaction.complete(); + transaction.ctx.exch.close(); + + true + } else { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::WriteResponse as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + tw.start_struct(TagType::Anonymous)?; + tw.start_array(TagType::Context(msg::WriteRespTag::WriteResponses as u8))?; + + false + } + } + Interaction::Invoke(request) => { + if transaction.has_timed_out() { + Self::create_status_response(tx, IMStatusCode::Timeout)?; + + transaction.complete(); + transaction.ctx.exch.close(); + + true + } else { + let timed_tx = transaction.get_timeout().map(|_| true); + let timed_request = request.timed_request.filter(|a| *a); + + // Either both should be None, or both should be Some(true) + if timed_tx != timed_request { + Self::create_status_response(tx, IMStatusCode::TimedRequestMisMatch)?; + + true + } else { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::InvokeResponse as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + tw.start_struct(TagType::Anonymous)?; + + // Suppress Response -> TODO: Need to revisit this for cases where we send a command back + tw.bool( + TagType::Context(msg::InvRespTag::SupressResponse as u8), + false, + )?; + + if request.inv_requests.is_some() { + tw.start_array(TagType::Context( + msg::InvRespTag::InvokeResponses as u8, + ))?; + } + + false + } + } + } + Interaction::Timed(request) => { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::StatusResponse as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + transaction.set_timeout(request.timeout.into()); + + let status = StatusResp { + status: IMStatusCode::Success, + }; + + status.to_tlv(&mut tw, TagType::Anonymous)?; + + true + } }; - let mut tw = TLVWriter::new(proto_tx.get_writebuf()?); - let _ = status.to_tlv(&mut tw, TagType::Anonymous); - Ok(ResponseRequired::Yes) + + Ok(!reply) } - /// Handle Request Timeouts - /// This API checks if a request was a timed request, and if so, and if the timeout has - /// expired, it will generate the appropriate response as expected - pub(super) fn req_timeout_handled( - trans: &mut Transaction, - proto_tx: &mut Packet, + pub fn complete_tx( + &self, + tx: &mut Packet, + transaction: &mut Transaction, ) -> Result { - if trans.has_timed_out() { - trans.complete(); - InteractionModel::create_status_response(proto_tx, IMStatusCode::Timeout)?; - Ok(true) - } else { - Ok(false) - } - } + let reply = match self { + Self::Read(request) => { + let mut tw = TLVWriter::new(tx.get_writebuf()?); - pub(super) fn create_status_response( - proto_tx: &mut Packet, - status: IMStatusCode, - ) -> Result<(), Error> { - proto_tx.set_proto_opcode(OpCode::StatusResponse as u8); - let mut tw = TLVWriter::new(proto_tx.get_writebuf()?); - let status = StatusResp { status }; - status.to_tlv(&mut tw, TagType::Anonymous) - } -} + if request.attr_requests.is_some() { + tw.end_container()?; + } -impl proto_demux::HandleProto for InteractionModel { - fn handle_proto_id(&mut self, ctx: &mut ProtoCtx) -> Result { - let mut trans = Transaction::new(&mut ctx.exch_ctx.sess, ctx.exch_ctx.exch); - let proto_opcode: OpCode = - num::FromPrimitive::from_u8(ctx.rx.get_proto_opcode()).ok_or(Error::Invalid)?; - ctx.tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); - - let buf = ctx.rx.as_borrow_slice(); - info!("{} {:?}", "Received command".cyan(), proto_opcode); - tlv::print_tlv_list(buf); - let result = match proto_opcode { - OpCode::InvokeRequest => self.handle_invoke_req(&mut trans, buf, &mut ctx.tx)?, - OpCode::ReadRequest => self.handle_read_req(&mut trans, buf, &mut ctx.tx)?, - OpCode::WriteRequest => self.handle_write_req(&mut trans, buf, &mut ctx.tx)?, - OpCode::TimedRequest => self.handle_timed_req(&mut trans, buf, &mut ctx.tx)?, - OpCode::SubscribeRequest => self.handle_subscribe_req(&mut trans, buf, &mut ctx.tx)?, - OpCode::StatusResponse => self.handle_status_resp(&mut trans, buf, &mut ctx.tx)?, - _ => { - error!("Opcode Not Handled: {:?}", proto_opcode); - return Err(Error::InvalidOpcode); + // Suppress response always true for read interaction + tw.bool( + TagType::Context(msg::ReportDataTag::SupressResponse as u8), + true, + )?; + + tw.end_container()?; + + transaction.complete(); + + true } + Self::Write(request) => { + let suppress = request.supress_response.unwrap_or_default(); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + tw.end_container()?; + tw.end_container()?; + + transaction.complete(); + + if suppress { + error!("Supress response is set, is this the expected handling?"); + false + } else { + true + } + } + Self::Invoke(request) => { + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + if request.inv_requests.is_some() { + tw.end_container()?; + } + + tw.end_container()?; + + true + } + Self::Timed(_) => false, }; - if result == ResponseRequired::Yes { + if reply { info!("Sending response"); - tlv::print_tlv_list(ctx.tx.as_borrow_slice()); + print_tlv_list(tx.as_slice()); } - if trans.is_terminate() { - ctx.exch_ctx.exch.terminate(); - } else if trans.is_complete() { - ctx.exch_ctx.exch.close(); + + if transaction.is_terminate() { + transaction.ctx.exch.terminate(); + } else if transaction.is_complete() { + transaction.ctx.exch.close(); } - Ok(result) + + Ok(true) } - fn get_proto_id(&self) -> usize { - PROTO_ID_INTERACTION_MODEL + fn create_status_response(tx: &mut Packet, status: IMStatusCode) -> Result<(), Error> { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::StatusResponse as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + let status = StatusResp { status }; + status.to_tlv(&mut tw, TagType::Anonymous) } } -#[derive(FromPrimitive, Debug, Clone, Copy, PartialEq)] -pub enum IMStatusCode { - Success = 0, - Failure = 1, - InvalidSubscription = 0x7D, - UnsupportedAccess = 0x7E, - UnsupportedEndpoint = 0x7F, - InvalidAction = 0x80, - UnsupportedCommand = 0x81, - InvalidCommand = 0x85, - UnsupportedAttribute = 0x86, - ConstraintError = 0x87, - UnsupportedWrite = 0x88, - ResourceExhausted = 0x89, - NotFound = 0x8b, - UnreportableAttribute = 0x8c, - InvalidDataType = 0x8d, - UnsupportedRead = 0x8f, - DataVersionMismatch = 0x92, - Timeout = 0x94, - Busy = 0x9c, - UnsupportedCluster = 0xc3, - NoUpstreamSubscription = 0xc5, - NeedsTimedInteraction = 0xc6, - UnsupportedEvent = 0xc7, - PathsExhausted = 0xc8, - TimedRequestMisMatch = 0xc9, - FailSafeRequired = 0xca, +pub trait InteractionHandler { + fn handle<'a>(&mut self, ctx: &'a mut ProtoCtx) -> Result, Error>; } -impl From for IMStatusCode { - fn from(e: Error) -> Self { - match e { - Error::EndpointNotFound => IMStatusCode::UnsupportedEndpoint, - Error::ClusterNotFound => IMStatusCode::UnsupportedCluster, - Error::AttributeNotFound => IMStatusCode::UnsupportedAttribute, - Error::CommandNotFound => IMStatusCode::UnsupportedCommand, - _ => IMStatusCode::Failure, - } +impl InteractionHandler for &mut T +where + T: InteractionHandler, +{ + fn handle<'a>(&mut self, ctx: &'a mut ProtoCtx) -> Result, Error> { + (**self).handle(ctx) } } -impl FromTLV<'_> for IMStatusCode { - fn from_tlv(t: &TLVElement) -> Result { - num::FromPrimitive::from_u16(t.u16()?).ok_or(Error::Invalid) +pub struct InteractionModel(pub T); + +impl InteractionModel +where + T: DataHandler, +{ + pub fn handle<'a>(&mut self, ctx: &'a mut ProtoCtx) -> Result, Error> { + let interaction = Interaction::new(ctx.rx)?; + let mut transaction = Transaction::new(&mut ctx.exch_ctx); + + let reply = if interaction.initiate_tx(ctx.tx, &mut transaction)? { + self.0.handle(&interaction, ctx.tx, &mut transaction)?; + interaction.complete_tx(ctx.tx, &mut transaction)? + } else { + true + }; + + Ok(reply.then_some(ctx.tx.as_slice())) } } -impl ToTLV for IMStatusCode { - fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { - tw.u16(tag_type, *self as u16) +#[cfg(feature = "nightly")] +impl InteractionModel +where + T: crate::data_model::core::asynch::AsyncDataHandler, +{ + pub async fn handle_async<'a>( + &mut self, + ctx: &'a mut ProtoCtx<'_, '_>, + ) -> Result, Error> { + let interaction = Interaction::new(ctx.rx)?; + let mut transaction = Transaction::new(&mut ctx.exch_ctx); + + let reply = if interaction.initiate_tx(ctx.tx, &mut transaction)? { + self.0 + .handle(&interaction, ctx.tx, &mut transaction) + .await?; + interaction.complete_tx(ctx.tx, &mut transaction)? + } else { + true + }; + + Ok(reply.then_some(ctx.tx.as_slice())) + } +} + +impl InteractionHandler for InteractionModel +where + T: DataHandler, +{ + fn handle<'a>(&mut self, ctx: &'a mut ProtoCtx) -> Result, Error> { + InteractionModel::handle(self, ctx) + } +} + +#[cfg(feature = "nightly")] +pub mod asynch { + use crate::{ + data_model::core::asynch::AsyncDataHandler, error::Error, transport::proto_ctx::ProtoCtx, + }; + + use super::InteractionModel; + + pub trait AsyncInteractionHandler { + async fn handle<'a>( + &mut self, + ctx: &'a mut ProtoCtx<'_, '_>, + ) -> Result, Error>; + } + + impl AsyncInteractionHandler for &mut T + where + T: AsyncInteractionHandler, + { + async fn handle<'a>( + &mut self, + ctx: &'a mut ProtoCtx<'_, '_>, + ) -> Result, Error> { + (**self).handle(ctx).await + } + } + + impl AsyncInteractionHandler for InteractionModel + where + T: AsyncDataHandler, + { + async fn handle<'a>( + &mut self, + ctx: &'a mut ProtoCtx<'_, '_>, + ) -> Result, Error> { + InteractionModel::handle_async(self, ctx).await + } } } diff --git a/matter/src/interaction_model/messages.rs b/matter/src/interaction_model/messages.rs index aac30f73..19e29f18 100644 --- a/matter/src/interaction_model/messages.rs +++ b/matter/src/interaction_model/messages.rs @@ -160,13 +160,6 @@ pub mod msg { pub inv_requests: Option>>, } - // This enum is helpful when we are constructing the response - // step by step in incremental manner - pub enum InvRespTag { - SupressResponse = 0, - InvokeResponses = 1, - } - #[derive(FromTLV, ToTLV, Debug)] #[tlvargs(lifetime = "'a")] pub struct InvResp<'a> { @@ -174,7 +167,14 @@ pub mod msg { pub inv_responses: Option>>, } - #[derive(Default, ToTLV, FromTLV)] + // This enum is helpful when we are constructing the response + // step by step in incremental manner + pub enum InvRespTag { + SupressResponse = 0, + InvokeResponses = 1, + } + + #[derive(Default, ToTLV, FromTLV, Debug)] #[tlvargs(lifetime = "'a")] pub struct ReadReq<'a> { pub attr_requests: Option>, @@ -198,17 +198,17 @@ pub mod msg { } } - #[derive(ToTLV, FromTLV)] - #[tlvargs(lifetime = "'b")] - pub struct WriteReq<'a, 'b> { + #[derive(FromTLV, ToTLV, Debug)] + #[tlvargs(lifetime = "'a")] + pub struct WriteReq<'a> { pub supress_response: Option, timed_request: Option, - pub write_requests: TLVArray<'a, AttrData<'b>>, + pub write_requests: TLVArray<'a, AttrData<'a>>, more_chunked: Option, } - impl<'a, 'b> WriteReq<'a, 'b> { - pub fn new(supress_response: bool, write_requests: &'a [AttrData<'b>]) -> Self { + impl<'a> WriteReq<'a> { + pub fn new(supress_response: bool, write_requests: &'a [AttrData<'a>]) -> Self { let mut w = Self { supress_response: None, write_requests: TLVArray::new(write_requests), @@ -223,7 +223,7 @@ pub mod msg { } // Report Data - #[derive(FromTLV, ToTLV)] + #[derive(FromTLV, ToTLV, Debug)] #[tlvargs(lifetime = "'a")] pub struct ReportDataMsg<'a> { pub subscription_id: Option, @@ -243,7 +243,7 @@ pub mod msg { } // Write Response - #[derive(ToTLV, FromTLV)] + #[derive(ToTLV, FromTLV, Debug)] #[tlvargs(lifetime = "'a")] pub struct WriteResp<'a> { pub write_responses: TLVArray<'a, AttrStatus>, @@ -255,10 +255,10 @@ pub mod msg { } pub mod ib { - use std::fmt::Debug; + use core::fmt::Debug; use crate::{ - data_model::objects::{AttrDetails, AttrId, ClusterId, EncodeValue, EndptId}, + data_model::objects::{AttrDetails, AttrId, ClusterId, CmdId, EncodeValue, EndptId}, error::Error, interaction_model::core::IMStatusCode, tlv::{FromTLV, Nullable, TLVElement, TLVWriter, TagType, ToTLV}, @@ -276,18 +276,6 @@ pub mod ib { } impl<'a> InvResp<'a> { - pub fn cmd_new( - endpoint: EndptId, - cluster: ClusterId, - cmd: u16, - data: EncodeValue<'a>, - ) -> Self { - Self::Cmd(CmdData::new( - CmdPath::new(Some(endpoint), Some(cluster), Some(cmd)), - data, - )) - } - pub fn status_new(cmd_path: CmdPath, status: IMStatusCode, cluster_status: u16) -> Self { Self::Status(CmdStatus { path: cmd_path, @@ -296,6 +284,23 @@ pub mod ib { } } + impl<'a> From> for InvResp<'a> { + fn from(value: CmdData<'a>) -> Self { + Self::Cmd(value) + } + } + + pub enum InvRespTag { + Cmd = 0, + Status = 1, + } + + impl<'a> From for InvResp<'a> { + fn from(value: CmdStatus) -> Self { + Self::Status(value) + } + } + #[derive(FromTLV, ToTLV, Copy, Clone, PartialEq, Debug)] pub struct CmdStatus { path: CmdPath, @@ -327,6 +332,11 @@ pub mod ib { } } + pub enum CmdDataTag { + Path = 0, + Data = 1, + } + // Status #[derive(Debug, Clone, Copy, PartialEq, FromTLV, ToTLV)] pub struct Status { @@ -352,10 +362,6 @@ pub mod ib { } impl<'a> AttrResp<'a> { - pub fn new(data_ver: u32, path: &AttrPath, data: EncodeValue<'a>) -> Self { - AttrResp::Data(AttrData::new(Some(data_ver), *path, data)) - } - pub fn unwrap_data(self) -> AttrData<'a> { match self { AttrResp::Data(d) => d, @@ -366,6 +372,23 @@ pub mod ib { } } + impl<'a> From> for AttrResp<'a> { + fn from(value: AttrData<'a>) -> Self { + Self::Data(value) + } + } + + impl<'a> From for AttrResp<'a> { + fn from(value: AttrStatus) -> Self { + Self::Status(value) + } + } + + pub enum AttrRespTag { + Status = 0, + Data = 1, + } + // Attribute Data #[derive(Clone, Copy, PartialEq, FromTLV, ToTLV, Debug)] #[tlvargs(lifetime = "'a")] @@ -385,6 +408,12 @@ pub mod ib { } } + pub enum AttrDataTag { + DataVer = 0, + Path = 1, + Data = 2, + } + #[derive(Debug)] /// Operations on an Interaction Model List pub enum ListOperation { @@ -399,13 +428,9 @@ pub mod ib { } /// Attribute Lists in Attribute Data are special. Infer the correct meaning using this function - pub fn attr_list_write( - attr: &AttrDetails, - data: &TLVElement, - mut f: F, - ) -> Result<(), IMStatusCode> + pub fn attr_list_write(attr: &AttrDetails, data: &TLVElement, mut f: F) -> Result<(), Error> where - F: FnMut(ListOperation, &TLVElement) -> Result<(), IMStatusCode>, + F: FnMut(ListOperation, &TLVElement) -> Result<(), Error>, { if let Some(Nullable::NotNull(index)) = attr.list_index { // If list index is valid, @@ -499,13 +524,13 @@ pub mod ib { pub fn new( endpoint: Option, cluster: Option, - command: Option, + command: Option, ) -> Self { Self { path: GenericPath { endpoint, cluster, - leaf: command.map(|a| a as u32), + leaf: command, }, } } @@ -532,20 +557,20 @@ pub mod ib { } } - #[derive(FromTLV, ToTLV, Copy, Clone)] + #[derive(FromTLV, ToTLV, Copy, Clone, Debug)] pub struct ClusterPath { pub node: Option, pub endpoint: EndptId, pub cluster: ClusterId, } - #[derive(FromTLV, ToTLV, Copy, Clone)] + #[derive(FromTLV, ToTLV, Copy, Clone, Debug)] pub struct DataVersionFilter { pub path: ClusterPath, pub data_ver: u32, } - #[derive(FromTLV, ToTLV, Copy, Clone)] + #[derive(FromTLV, ToTLV, Copy, Clone, Debug)] #[tlvargs(datatype = "list")] pub struct EventPath { pub node: Option, @@ -555,7 +580,7 @@ pub mod ib { pub is_urgent: Option, } - #[derive(FromTLV, ToTLV, Copy, Clone)] + #[derive(FromTLV, ToTLV, Copy, Clone, Debug)] pub struct EventFilter { pub node: Option, pub event_min: Option, diff --git a/matter/src/interaction_model/mod.rs b/matter/src/interaction_model/mod.rs index 2caf55f8..22e0ee96 100644 --- a/matter/src/interaction_model/mod.rs +++ b/matter/src/interaction_model/mod.rs @@ -15,73 +15,5 @@ * limitations under the License. */ -use crate::{ - error::Error, - tlv::TLVWriter, - transport::{exchange::Exchange, proto_demux::ResponseRequired, session::SessionHandle}, -}; - -use self::{ - core::OpCode, - messages::msg::{InvReq, StatusResp, WriteReq}, -}; - -#[derive(PartialEq)] -pub enum TransactionState { - Ongoing, - Complete, - Terminate, -} -pub struct Transaction<'a, 'b> { - pub state: TransactionState, - pub session: &'a mut SessionHandle<'b>, - pub exch: &'a mut Exchange, -} - -pub trait InteractionConsumer { - fn consume_invoke_cmd( - &self, - req: &InvReq, - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(), Error>; - - fn consume_read_attr( - &self, - // TODO: This handling is different from the other APIs here, identify - // consistent options for this trait - req: &[u8], - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(), Error>; - - fn consume_write_attr( - &self, - req: &WriteReq, - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(), Error>; - - fn consume_status_report( - &self, - _req: &StatusResp, - _trans: &mut Transaction, - _tw: &mut TLVWriter, - ) -> Result<(OpCode, ResponseRequired), Error>; - - fn consume_subscribe( - &self, - _req: &[u8], - _trans: &mut Transaction, - _tw: &mut TLVWriter, - ) -> Result<(OpCode, ResponseRequired), Error>; -} - -pub struct InteractionModel { - consumer: Box, -} -pub mod command; pub mod core; pub mod messages; -pub mod read; -pub mod write; diff --git a/matter/src/interaction_model/read.rs b/matter/src/interaction_model/read.rs deleted file mode 100644 index 0985eeae..00000000 --- a/matter/src/interaction_model/read.rs +++ /dev/null @@ -1,42 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use crate::{ - error::Error, - interaction_model::core::OpCode, - tlv::TLVWriter, - transport::{packet::Packet, proto_demux::ResponseRequired}, -}; - -use super::{InteractionModel, Transaction}; - -impl InteractionModel { - pub fn handle_read_req( - &mut self, - trans: &mut Transaction, - rx_buf: &[u8], - proto_tx: &mut Packet, - ) -> Result { - proto_tx.set_proto_opcode(OpCode::ReportData as u8); - let proto_tx_wb = proto_tx.get_writebuf()?; - let mut tw = TLVWriter::new(proto_tx_wb); - - self.consumer.consume_read_attr(rx_buf, trans, &mut tw)?; - - Ok(ResponseRequired::Yes) - } -} diff --git a/matter/src/interaction_model/write.rs b/matter/src/interaction_model/write.rs deleted file mode 100644 index 48b79036..00000000 --- a/matter/src/interaction_model/write.rs +++ /dev/null @@ -1,58 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use log::error; - -use crate::{ - error::Error, - tlv::{get_root_node_struct, FromTLV, TLVWriter, TagType}, - transport::{packet::Packet, proto_demux::ResponseRequired}, -}; - -use super::{core::OpCode, messages::msg::WriteReq, InteractionModel, Transaction}; - -impl InteractionModel { - pub fn handle_write_req( - &mut self, - trans: &mut Transaction, - rx_buf: &[u8], - proto_tx: &mut Packet, - ) -> Result { - if InteractionModel::req_timeout_handled(trans, proto_tx)? { - return Ok(ResponseRequired::Yes); - } - proto_tx.set_proto_opcode(OpCode::WriteResponse as u8); - - let mut tw = TLVWriter::new(proto_tx.get_writebuf()?); - let root = get_root_node_struct(rx_buf)?; - let write_req = WriteReq::from_tlv(&root)?; - let supress_response = write_req.supress_response.unwrap_or_default(); - - tw.start_struct(TagType::Anonymous)?; - self.consumer - .consume_write_attr(&write_req, trans, &mut tw)?; - tw.end_container()?; - - trans.complete(); - if supress_response { - error!("Supress response is set, is this the expected handling?"); - Ok(ResponseRequired::No) - } else { - Ok(ResponseRequired::Yes) - } - } -} diff --git a/matter/src/lib.rs b/matter/src/lib.rs index 9f03ac59..0d99cdb7 100644 --- a/matter/src/lib.rs +++ b/matter/src/lib.rs @@ -23,7 +23,7 @@ //! Currently Ethernet based transport is supported. //! //! # Examples -//! ``` +//! TODO: Fix once new API has stabilized a bit //! use matter::{Matter, CommissioningData}; //! use matter::data_model::device_types::device_type_add_on_off_light; //! use matter::data_model::cluster_basic_information::BasicInfoConfig; @@ -65,8 +65,11 @@ //! } //! // Start the Matter Daemon //! // matter.start_daemon().unwrap(); -//! ``` +//! //! Start off exploring by going to the [Matter] object. +#![cfg_attr(not(feature = "std"), no_std)] +#![cfg_attr(feature = "nightly", feature(async_fn_in_trait))] +#![cfg_attr(feature = "nightly", allow(incomplete_features))] pub mod acl; pub mod cert; @@ -80,6 +83,7 @@ pub mod group_keys; pub mod interaction_model; pub mod mdns; pub mod pairing; +pub mod persist; pub mod secure_channel; pub mod sys; pub mod tlv; diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index f28bea08..71be231f 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -15,34 +15,58 @@ * limitations under the License. */ -use std::sync::{Arc, Mutex, Once}; +use core::fmt::Write; -use crate::{ - error::Error, - sys::{sys_publish_service, SysMdnsService}, - transport::udp::MATTER_PORT, -}; +use crate::error::Error; -#[derive(Default)] -/// The mDNS service handler -pub struct MdnsInner { - /// Vendor ID - vid: u16, - /// Product ID - pid: u16, - /// Device name - device_name: String, +pub trait Mdns { + fn add( + &mut self, + name: &str, + service_type: &str, + port: u16, + txt_kvs: &[(&str, &str)], + ) -> Result<(), Error>; + + fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error>; } -pub struct Mdns { - inner: Mutex, +impl Mdns for &mut T +where + T: Mdns, +{ + fn add( + &mut self, + name: &str, + service_type: &str, + port: u16, + txt_kvs: &[(&str, &str)], + ) -> Result<(), Error> { + (**self).add(name, service_type, port, txt_kvs) + } + + fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { + (**self).remove(name, service_type, port) + } } -const SHORT_DISCRIMINATOR_MASK: u16 = 0xF00; -const SHORT_DISCRIMINATOR_SHIFT: u16 = 8; +pub struct DummyMdns; + +impl Mdns for DummyMdns { + fn add( + &mut self, + _name: &str, + _service_type: &str, + _port: u16, + _txt_kvs: &[(&str, &str)], + ) -> Result<(), Error> { + Ok(()) + } -static mut G_MDNS: Option> = None; -static INIT: Once = Once::new(); + fn remove(&mut self, _name: &str, _service_type: &str, _port: u16) -> Result<(), Error> { + Ok(()) + } +} pub enum ServiceMode { /// The commissioned state @@ -51,68 +75,108 @@ pub enum ServiceMode { Commissionable(u16), } -impl Mdns { - fn new() -> Self { - Self { - inner: Mutex::new(MdnsInner { - ..Default::default() - }), - } - } +/// The mDNS service handler +pub struct MdnsMgr<'a> { + /// Vendor ID + vid: u16, + /// Product ID + pid: u16, + /// Device name + device_name: heapless::String<32>, + /// Matter port + matter_port: u16, + /// mDns service + mdns: &'a mut dyn Mdns, +} - /// Get a handle to the globally unique mDNS instance - pub fn get() -> Result, Error> { - unsafe { - INIT.call_once(|| { - G_MDNS = Some(Arc::new(Mdns::new())); - }); - Ok(G_MDNS.as_ref().ok_or(Error::Invalid)?.clone()) +impl<'a> MdnsMgr<'a> { + pub fn new( + vid: u16, + pid: u16, + device_name: &str, + matter_port: u16, + mdns: &'a mut dyn Mdns, + ) -> Self { + Self { + vid, + pid, + device_name: device_name.chars().take(32).collect(), + matter_port, + mdns, } } - /// Set mDNS service specific values - /// Values like vid, pid, discriminator etc - // TODO: More things like device-type etc can be added here - pub fn set_values(&self, vid: u16, pid: u16, device_name: &str) { - let mut inner = self.inner.lock().unwrap(); - inner.vid = vid; - inner.pid = pid; - inner.device_name = device_name.chars().take(32).collect(); - } - - /// Publish a mDNS service + /// Publish an mDNS service /// name - is the service name (comma separated subtypes may follow) /// mode - the current service mode #[allow(clippy::needless_pass_by_value)] - pub fn publish_service(&self, name: &str, mode: ServiceMode) -> Result { + pub fn publish_service(&mut self, name: &str, mode: ServiceMode) -> Result<(), Error> { match mode { - ServiceMode::Commissioned => { - sys_publish_service(name, "_matter._tcp", MATTER_PORT, &[]) - } + ServiceMode::Commissioned => self.mdns.add(name, "_matter._tcp", self.matter_port, &[]), ServiceMode::Commissionable(discriminator) => { - let inner = self.inner.lock().unwrap(); - let short = compute_short_discriminator(discriminator); - let serv_type = format!("_matterc._udp,_S{},_L{}", short, discriminator); + let discriminator_str = Self::get_discriminator_str(discriminator); + + let serv_type = self.get_service_type(discriminator); + let vp = self.get_vp(); - let str_discriminator = format!("{}", discriminator); let txt_kvs = [ - ["D", &str_discriminator], - ["CM", "1"], - ["DN", &inner.device_name], - ["VP", &format!("{}+{}", inner.vid, inner.pid)], - ["SII", "5000"], /* Sleepy Idle Interval */ - ["SAI", "300"], /* Sleepy Active Interval */ - ["PH", "33"], /* Pairing Hint */ - ["PI", ""], /* Pairing Instruction */ + ("D", discriminator_str.as_str()), + ("CM", "1"), + ("DN", self.device_name.as_str()), + ("VP", &vp), + ("SII", "5000"), /* Sleepy Idle Interval */ + ("SAI", "300"), /* Sleepy Active Interval */ + ("PH", "33"), /* Pairing Hint */ + ("PI", ""), /* Pairing Instruction */ ]; - sys_publish_service(name, &serv_type, MATTER_PORT, &txt_kvs) + self.mdns.add(name, &serv_type, self.matter_port, &txt_kvs) + } + } + } + + pub fn unpublish_service(&mut self, name: &str, mode: ServiceMode) -> Result<(), Error> { + match mode { + ServiceMode::Commissioned => self.mdns.remove(name, "_matter._tcp", self.matter_port), + ServiceMode::Commissionable(discriminator) => { + let serv_type = self.get_service_type(discriminator); + + self.mdns.remove(name, &serv_type, self.matter_port) } } } -} -fn compute_short_discriminator(discriminator: u16) -> u16 { - (discriminator & SHORT_DISCRIMINATOR_MASK) >> SHORT_DISCRIMINATOR_SHIFT + fn get_service_type(&self, discriminator: u16) -> heapless::String<32> { + let short = Self::compute_short_discriminator(discriminator); + let mut serv_type = heapless::String::new(); + + write!( + &mut serv_type, + "_matterc._udp,_S{},_L{}", + short, discriminator + ) + .unwrap(); + + serv_type + } + + fn get_vp(&self) -> heapless::String<11> { + let mut vp = heapless::String::new(); + + write!(&mut vp, "{}+{}", self.vid, self.pid).unwrap(); + + vp + } + + fn get_discriminator_str(discriminator: u16) -> heapless::String<5> { + discriminator.into() + } + + fn compute_short_discriminator(discriminator: u16) -> u16 { + const SHORT_DISCRIMINATOR_MASK: u16 = 0xF00; + const SHORT_DISCRIMINATOR_SHIFT: u16 = 8; + + (discriminator & SHORT_DISCRIMINATOR_MASK) >> SHORT_DISCRIMINATOR_SHIFT + } } #[cfg(test)] @@ -122,11 +186,11 @@ mod tests { #[test] fn can_compute_short_discriminator() { let discriminator: u16 = 0b0000_1111_0000_0000; - let short = compute_short_discriminator(discriminator); + let short = MdnsMgr::compute_short_discriminator(discriminator); assert_eq!(short, 0b1111); let discriminator: u16 = 840; - let short = compute_short_discriminator(discriminator); + let short = MdnsMgr::compute_short_discriminator(discriminator); assert_eq!(short, 3); } } diff --git a/matter/src/pairing/code.rs b/matter/src/pairing/code.rs index 83d90f37..16e4feab 100644 --- a/matter/src/pairing/code.rs +++ b/matter/src/pairing/code.rs @@ -15,56 +15,66 @@ * limitations under the License. */ +use core::fmt::Write; + use super::*; -pub(super) fn compute_pairing_code(comm_data: &CommissioningData) -> String { +pub(super) fn compute_pairing_code(comm_data: &CommissioningData) -> heapless::String<32> { // 0: no Vendor ID and Product ID present in Manual Pairing Code const VID_PID_PRESENT: u8 = 0; let passwd = passwd_from_comm_data(comm_data); let CommissioningData { discriminator, .. } = comm_data; - let mut digits = String::new(); - digits.push_str(&((VID_PID_PRESENT << 2) | (discriminator >> 10) as u8).to_string()); - digits.push_str(&format!( - "{:0>5}", - ((discriminator & 0x300) << 6) | (passwd & 0x3FFF) as u16 - )); - digits.push_str(&format!("{:0>4}", passwd >> 14)); + let mut digits = heapless::String::<32>::new(); + write!( + &mut digits, + "{}{:0>5}{:0>4}", + (VID_PID_PRESENT << 2) | (discriminator >> 10) as u8, + ((discriminator & 0x300) << 6) | (passwd & 0x3FFF) as u16, + passwd >> 14 + ) + .unwrap(); - let check_digit = digits.calculate_verhoeff_check_digit(); - digits.push_str(&check_digit.to_string()); + let mut final_digits = heapless::String::<32>::new(); + write!( + &mut final_digits, + "{}{}", + digits, + digits.calculate_verhoeff_check_digit() + ) + .unwrap(); - digits + final_digits } pub(super) fn pretty_print_pairing_code(pairing_code: &str) { assert!(pairing_code.len() == 11); - let mut pretty = String::new(); - pretty.push_str(&pairing_code[..4]); - pretty.push('-'); - pretty.push_str(&pairing_code[4..8]); - pretty.push('-'); - pretty.push_str(&pairing_code[8..]); + let mut pretty = heapless::String::<32>::new(); + pretty.push_str(&pairing_code[..4]).unwrap(); + pretty.push('-').unwrap(); + pretty.push_str(&pairing_code[4..8]).unwrap(); + pretty.push('-').unwrap(); + pretty.push_str(&pairing_code[8..]).unwrap(); info!("Pairing Code: {}", pretty); } #[cfg(test)] mod tests { use super::*; - use crate::secure_channel::spake2p::VerifierData; + use crate::{secure_channel::spake2p::VerifierData, utils::rand::dummy_rand}; #[test] fn can_compute_pairing_code() { let comm_data = CommissioningData { - verifier: VerifierData::new_with_pw(123456), + verifier: VerifierData::new_with_pw(123456, dummy_rand), discriminator: 250, }; let pairing_code = compute_pairing_code(&comm_data); assert_eq!(pairing_code, "00876800071"); let comm_data = CommissioningData { - verifier: VerifierData::new_with_pw(34567890), + verifier: VerifierData::new_with_pw(34567890, dummy_rand), discriminator: 2976, }; let pairing_code = compute_pairing_code(&comm_data); diff --git a/matter/src/pairing/qr.rs b/matter/src/pairing/qr.rs index 0a3509db..f1d844a5 100644 --- a/matter/src/pairing/qr.rs +++ b/matter/src/pairing/qr.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use std::collections::BTreeMap; +use heapless::FnvIndexMap; use crate::{ tlv::{TLVWriter, TagType}, @@ -55,7 +55,7 @@ const SERIAL_NUMBER_TAG: u8 = 0x00; // const COMMISSIONING_TIMEOUT_TAG: u8 = 0x04; pub enum QRCodeInfoType { - String(String), + String(heapless::String<128>), // TODO: Big enough? Int32(i32), Int64(i64), UInt32(u32), @@ -63,7 +63,7 @@ pub enum QRCodeInfoType { } pub enum SerialNumber { - String(String), + String(heapless::String<128>), UInt32(u32), } @@ -78,10 +78,10 @@ pub struct QrSetupPayload<'data> { version: u8, flow_type: CommissionningFlowType, discovery_capabilities: DiscoveryCapabilities, - dev_det: &'data BasicInfoConfig, + dev_det: &'data BasicInfoConfig<'data>, comm_data: &'data CommissioningData, // we use a BTreeMap to keep the order of the optional data stable - optional_data: BTreeMap, + optional_data: heapless::FnvIndexMap, } impl<'data> QrSetupPayload<'data> { @@ -98,11 +98,11 @@ impl<'data> QrSetupPayload<'data> { discovery_capabilities, dev_det, comm_data, - optional_data: BTreeMap::new(), + optional_data: FnvIndexMap::new(), }; if !dev_det.serial_no.is_empty() { - result.add_serial_number(SerialNumber::String(dev_det.serial_no.clone())); + result.add_serial_number(SerialNumber::String(dev_det.serial_no.into())); } result @@ -137,7 +137,9 @@ impl<'data> QrSetupPayload<'data> { } self.optional_data - .insert(tag, OptionalQRCodeInfo { tag, data }); + .insert(tag, OptionalQRCodeInfo { tag, data }) + .map_err(|_| Error::NoSpace)?; + Ok(()) } @@ -155,11 +157,13 @@ impl<'data> QrSetupPayload<'data> { } self.optional_data - .insert(tag, OptionalQRCodeInfo { tag, data }); + .insert(tag, OptionalQRCodeInfo { tag, data }) + .map_err(|_| Error::NoSpace)?; + Ok(()) } - pub fn get_all_optional_data(&self) -> &BTreeMap { + pub fn get_all_optional_data(&self) -> &FnvIndexMap { &self.optional_data } @@ -388,7 +392,7 @@ fn generate_tlv_from_optional_data( ) -> Result<(), Error> { let size_needed = tlv_data.max_data_length_in_bytes as usize; let mut tlv_buffer = vec![0u8; size_needed]; - let mut wb = WriteBuf::new(&mut tlv_buffer, size_needed); + let mut wb = WriteBuf::new(&mut tlv_buffer); let mut tw = TLVWriter::new(&mut wb); tw.start_struct(TagType::Anonymous)?; @@ -532,16 +536,15 @@ fn is_common_tag(tag: u8) -> bool { #[cfg(test)] mod tests { - use super::*; - use crate::secure_channel::spake2p::VerifierData; + use crate::{secure_channel::spake2p::VerifierData, utils::rand::dummy_rand}; #[test] fn can_base38_encode() { const QR_CODE: &str = "MT:YNJV7VSC00CMVH7SR00"; let comm_data = CommissioningData { - verifier: VerifierData::new_with_pw(34567890), + verifier: VerifierData::new_with_pw(34567890, dummy_rand), discriminator: 2976, }; let dev_det = BasicInfoConfig { @@ -561,13 +564,13 @@ mod tests { const QR_CODE: &str = "MT:-24J0AFN00KA064IJ3P0IXZB0DK5N1K8SQ1RYCU1-A40"; let comm_data = CommissioningData { - verifier: VerifierData::new_with_pw(20202021), + verifier: VerifierData::new_with_pw(20202021, dummy_rand), discriminator: 3840, }; let dev_det = BasicInfoConfig { vid: 65521, pid: 32769, - serial_no: "1234567890".to_string(), + serial_no: "1234567890", ..Default::default() }; @@ -588,13 +591,13 @@ mod tests { const OPTIONAL_DEFAULT_INT_VALUE: i32 = 65550; let comm_data = CommissioningData { - verifier: VerifierData::new_with_pw(20202021), + verifier: VerifierData::new_with_pw(20202021, dummy_rand), discriminator: 3840, }; let dev_det = BasicInfoConfig { vid: 65521, pid: 32769, - serial_no: "1234567890".to_string(), + serial_no: "1234567890", ..Default::default() }; @@ -604,7 +607,7 @@ mod tests { qr_code_data .add_optional_vendor_data( OPTIONAL_DEFAULT_STRING_TAG, - QRCodeInfoType::String(OPTIONAL_DEFAULT_STRING_VALUE.to_string()), + QRCodeInfoType::String(OPTIONAL_DEFAULT_STRING_VALUE.into()), ) .expect("Failed to add optional data"); diff --git a/matter/src/persist.rs b/matter/src/persist.rs new file mode 100644 index 00000000..4bc8e241 --- /dev/null +++ b/matter/src/persist.rs @@ -0,0 +1,229 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use crate::error::Error; + +pub trait Psm { + fn set_kv_slice(&mut self, key: &str, val: &[u8]) -> Result<(), Error>; + fn get_kv_slice<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result<&'a [u8], Error>; + + fn set_kv_u64(&mut self, key: &str, val: u64) -> Result<(), Error>; + fn get_kv_u64(&self, key: &str) -> Result; + + fn remove(&mut self, key: &str) -> Result<(), Error>; +} + +impl Psm for &mut T +where + T: Psm, +{ + fn set_kv_slice(&mut self, key: &str, val: &[u8]) -> Result<(), Error> { + (**self).set_kv_slice(key, val) + } + + fn get_kv_slice<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result<&'a [u8], Error> { + (**self).get_kv_slice(key, buf) + } + + fn set_kv_u64(&mut self, key: &str, val: u64) -> Result<(), Error> { + (**self).set_kv_u64(key, val) + } + + fn get_kv_u64(&self, key: &str) -> Result { + (**self).get_kv_u64(key) + } + + fn remove(&mut self, key: &str) -> Result<(), Error> { + (**self).remove(key) + } +} + +#[cfg(feature = "nightly")] +pub mod asynch { + use crate::error::Error; + + use super::Psm; + + pub trait AsyncPsm { + async fn set_kv_slice<'a>(&'a mut self, key: &'a str, val: &'a [u8]) -> Result<(), Error>; + async fn get_kv_slice<'a, 'b>( + &'a self, + key: &'a str, + buf: &'b mut [u8], + ) -> Result<&'b [u8], Error>; + + async fn set_kv_u64<'a>(&'a mut self, key: &'a str, val: u64) -> Result<(), Error>; + async fn get_kv_u64<'a>(&'a self, key: &'a str) -> Result; + + async fn remove<'a>(&'a mut self, key: &'a str) -> Result<(), Error>; + } + + impl AsyncPsm for &mut T + where + T: AsyncPsm, + { + async fn set_kv_slice<'a>(&'a mut self, key: &'a str, val: &'a [u8]) -> Result<(), Error> { + (**self).set_kv_slice(key, val).await + } + + async fn get_kv_slice<'a, 'b>( + &'a self, + key: &'a str, + buf: &'b mut [u8], + ) -> Result<&'b [u8], Error> { + (**self).get_kv_slice(key, buf).await + } + + async fn set_kv_u64<'a>(&'a mut self, key: &'a str, val: u64) -> Result<(), Error> { + (**self).set_kv_u64(key, val).await + } + + async fn get_kv_u64<'a>(&'a self, key: &'a str) -> Result { + (**self).get_kv_u64(key).await + } + + async fn remove<'a>(&'a mut self, key: &'a str) -> Result<(), Error> { + (**self).remove(key).await + } + } + + pub struct Asyncify(pub T); + + impl AsyncPsm for Asyncify + where + T: Psm, + { + async fn set_kv_slice<'a>(&'a mut self, key: &'a str, val: &'a [u8]) -> Result<(), Error> { + self.0.set_kv_slice(key, val) + } + + async fn get_kv_slice<'a, 'b>( + &'a self, + key: &'a str, + buf: &'b mut [u8], + ) -> Result<&'b [u8], Error> { + self.0.get_kv_slice(key, buf) + } + + async fn set_kv_u64<'a>(&'a mut self, key: &'a str, val: u64) -> Result<(), Error> { + self.0.set_kv_u64(key, val) + } + + async fn get_kv_u64<'a>(&'a self, key: &'a str) -> Result { + self.0.get_kv_u64(key) + } + + async fn remove<'a>(&'a mut self, key: &'a str) -> Result<(), Error> { + self.0.remove(key) + } + } +} + +#[cfg(feature = "std")] +pub mod std { + use std::fs::{self, DirBuilder, File}; + use std::io::{Read, Write}; + + use crate::error::Error; + + use super::Psm; + + pub struct FilePsm {} + + const PSM_DIR: &str = "/tmp/matter_psm"; + + macro_rules! psm_path { + ($key:ident) => { + format!("{}/{}", PSM_DIR, $key) + }; + } + + impl FilePsm { + pub fn new() -> Result { + let result = DirBuilder::new().create(PSM_DIR); + if let Err(e) = result { + if e.kind() != std::io::ErrorKind::AlreadyExists { + return Err(e.into()); + } + } + + Ok(Self {}) + } + + pub fn set_kv_slice(&mut self, key: &str, val: &[u8]) -> Result<(), Error> { + let mut f = File::create(psm_path!(key))?; + f.write_all(val)?; + Ok(()) + } + + pub fn get_kv_slice<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result<&'a [u8], Error> { + let mut f = File::open(psm_path!(key))?; + let mut offset = 0; + + loop { + let len = f.read(&mut buf[offset..])?; + offset += len; + + if len == 0 { + break; + } + } + + Ok(&buf[..offset]) + } + + pub fn set_kv_u64(&mut self, key: &str, val: u64) -> Result<(), Error> { + let mut f = File::create(psm_path!(key))?; + f.write_all(&val.to_be_bytes())?; + Ok(()) + } + + pub fn get_kv_u64(&self, key: &str) -> Result { + let mut f = File::open(psm_path!(key))?; + let mut buf = [0; 8]; + f.read_exact(&mut buf)?; + Ok(u64::from_be_bytes(buf)) + } + + pub fn remove(&self, key: &str) -> Result<(), Error> { + fs::remove_file(psm_path!(key))?; + Ok(()) + } + } + + impl Psm for FilePsm { + fn set_kv_slice(&mut self, key: &str, val: &[u8]) -> Result<(), Error> { + FilePsm::set_kv_slice(self, key, val) + } + + fn get_kv_slice<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result<&'a [u8], Error> { + FilePsm::get_kv_slice(self, key, buf) + } + + fn set_kv_u64(&mut self, key: &str, val: u64) -> Result<(), Error> { + FilePsm::set_kv_u64(self, key, val) + } + + fn get_kv_u64(&self, key: &str) -> Result { + FilePsm::get_kv_u64(self, key) + } + + fn remove(&mut self, key: &str) -> Result<(), Error> { + FilePsm::remove(self, key) + } + } +} diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index 58a5593f..80802ede 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -15,27 +15,25 @@ * limitations under the License. */ -use std::sync::Arc; +use core::cell::RefCell; use log::{error, trace}; -use owning_ref::RwLockReadGuardRef; -use rand::prelude::*; use crate::{ cert::Cert, - crypto::{self, CryptoKeyPair, KeyPair, Sha256}, + crypto::{self, KeyPair, Sha256}, error::Error, - fabric::{Fabric, FabricMgr, FabricMgrInner}, + fabric::{Fabric, FabricMgr}, secure_channel::common::SCStatusCodes, secure_channel::common::{self, OpCode}, tlv::{get_root_node_struct, FromTLV, OctetStr, TLVElement, TLVWriter, TagType}, transport::{ network::Address, - proto_demux::{ProtoCtx, ResponseRequired}, + proto_ctx::ProtoCtx, queue::{Msg, WorkQ}, session::{CaseDetails, CloneData, NocCatIds, SessionMode}, }, - utils::writebuf::WriteBuf, + utils::{rand::Rand, writebuf::WriteBuf}, }; #[derive(PartialEq)] @@ -54,6 +52,7 @@ pub struct CaseSession { peer_pub_key: [u8; crypto::EC_POINT_LEN_BYTES], local_fabric_idx: usize, } + impl CaseSession { pub fn new(peer_sessid: u16, local_sessid: u16) -> Result { Ok(Self { @@ -69,40 +68,43 @@ impl CaseSession { } } -pub struct Case { - fabric_mgr: Arc, +pub struct Case<'a> { + fabric_mgr: &'a RefCell, + rand: Rand, } -impl Case { - pub fn new(fabric_mgr: Arc) -> Self { - Self { fabric_mgr } +impl<'a> Case<'a> { + pub fn new(fabric_mgr: &'a RefCell, rand: Rand) -> Self { + Self { fabric_mgr, rand } } - pub fn casesigma3_handler(&mut self, ctx: &mut ProtoCtx) -> Result { + pub fn casesigma3_handler(&mut self, ctx: &mut ProtoCtx) -> Result { let mut case_session = ctx .exch_ctx .exch - .take_data_boxed::() + .take_case_session::() .ok_or(Error::InvalidState)?; if case_session.state != State::Sigma1Rx { return Err(Error::Invalid); } case_session.state = State::Sigma3Rx; - let fabric = self.fabric_mgr.get_fabric(case_session.local_fabric_idx)?; + let fabric_mgr = self.fabric_mgr.borrow(); + + let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; if fabric.is_none() { common::create_sc_status_report( - &mut ctx.tx, + ctx.tx, common::SCStatusCodes::NoSharedTrustRoots, None, )?; ctx.exch_ctx.exch.close(); - return Ok(ResponseRequired::Yes); + return Ok(true); } // Safe to unwrap here - let fabric = fabric.as_ref().as_ref().unwrap(); + let fabric = fabric.unwrap(); - let root = get_root_node_struct(ctx.rx.as_borrow_slice())?; + let root = get_root_node_struct(ctx.rx.as_slice())?; let encrypted = root.find_tag(1)?.slice()?; let mut decrypted: [u8; 800] = [0; 800]; @@ -126,13 +128,9 @@ impl Case { } if let Err(e) = Case::validate_certs(fabric, &initiator_noc, &initiator_icac) { error!("Certificate Chain doesn't match: {}", e); - common::create_sc_status_report( - &mut ctx.tx, - common::SCStatusCodes::InvalidParameter, - None, - )?; + common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?; ctx.exch_ctx.exch.close(); - return Ok(ResponseRequired::Yes); + return Ok(true); } if Case::validate_sigma3_sign( @@ -145,19 +143,15 @@ impl Case { .is_err() { error!("Sigma3 Signature doesn't match"); - common::create_sc_status_report( - &mut ctx.tx, - common::SCStatusCodes::InvalidParameter, - None, - )?; + common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?; ctx.exch_ctx.exch.close(); - return Ok(ResponseRequired::Yes); + return Ok(true); } // Only now do we add this message to the TT Hash let mut peer_catids: NocCatIds = Default::default(); initiator_noc.get_cat_ids(&mut peer_catids); - case_session.tt_hash.update(ctx.rx.as_borrow_slice())?; + case_session.tt_hash.update(ctx.rx.as_slice())?; let clone_data = Case::get_session_clone_data( fabric.ipk.op_key(), fabric.get_node_id(), @@ -169,40 +163,36 @@ impl Case { // Queue a transport mgr request to add a new session WorkQ::get()?.sync_send(Msg::NewSession(clone_data))?; - common::create_sc_status_report( - &mut ctx.tx, - SCStatusCodes::SessionEstablishmentSuccess, - None, - )?; - ctx.exch_ctx.exch.clear_data_boxed(); + common::create_sc_status_report(ctx.tx, SCStatusCodes::SessionEstablishmentSuccess, None)?; + ctx.exch_ctx.exch.clear_data(); ctx.exch_ctx.exch.close(); - - Ok(ResponseRequired::Yes) + Ok(true) } - pub fn casesigma1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { + pub fn casesigma1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { ctx.tx.set_proto_opcode(OpCode::CASESigma2 as u8); - let rx_buf = ctx.rx.as_borrow_slice(); + let rx_buf = ctx.rx.as_slice(); let root = get_root_node_struct(rx_buf)?; let r = Sigma1Req::from_tlv(&root)?; let local_fabric_idx = self .fabric_mgr + .borrow_mut() .match_dest_id(r.initiator_random.0, r.dest_id.0); if local_fabric_idx.is_err() { error!("Fabric Index mismatch"); common::create_sc_status_report( - &mut ctx.tx, + ctx.tx, common::SCStatusCodes::NoSharedTrustRoots, None, )?; ctx.exch_ctx.exch.close(); - return Ok(ResponseRequired::Yes); + return Ok(true); } let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id(); - let mut case_session = Box::new(CaseSession::new(r.initiator_sessid, local_sessid)?); + let mut case_session = CaseSession::new(r.initiator_sessid, local_sessid)?; case_session.tt_hash.update(rx_buf)?; case_session.local_fabric_idx = local_fabric_idx?; if r.peer_pub_key.0.len() != crypto::EC_POINT_LEN_BYTES { @@ -228,7 +218,7 @@ impl Case { // println!("Derived secret: {:x?} len: {}", secret, len); let mut our_random: [u8; 32] = [0; 32]; - rand::thread_rng().fill_bytes(&mut our_random); + (self.rand)(&mut our_random); // Derive the Encrypted Part const MAX_ENCRYPTED_SIZE: usize = 800; @@ -236,19 +226,21 @@ impl Case { let mut encrypted: [u8; MAX_ENCRYPTED_SIZE] = [0; MAX_ENCRYPTED_SIZE]; let encrypted_len = { let mut signature = [0u8; crypto::EC_SIGNATURE_LEN_BYTES]; - let fabric = self.fabric_mgr.get_fabric(case_session.local_fabric_idx)?; + let fabric_mgr = self.fabric_mgr.borrow(); + + let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; if fabric.is_none() { common::create_sc_status_report( - &mut ctx.tx, + ctx.tx, common::SCStatusCodes::NoSharedTrustRoots, None, )?; ctx.exch_ctx.exch.close(); - return Ok(ResponseRequired::Yes); + return Ok(true); } let sign_len = Case::get_sigma2_sign( - &fabric, + fabric.unwrap(), &case_session.our_pub_key, &case_session.peer_pub_key, &mut signature, @@ -256,7 +248,8 @@ impl Case { let signature = &signature[..sign_len]; Case::get_sigma2_encryption( - &fabric, + fabric.unwrap(), + self.rand, &our_random, &mut case_session, signature, @@ -273,9 +266,9 @@ impl Case { tw.str8(TagType::Context(3), &case_session.our_pub_key)?; tw.str16(TagType::Context(4), encrypted)?; tw.end_container()?; - case_session.tt_hash.update(ctx.tx.as_borrow_slice())?; - ctx.exch_ctx.exch.set_data_boxed(case_session); - Ok(ResponseRequired::Yes) + case_session.tt_hash.update(ctx.tx.as_mut_slice())?; + ctx.exch_ctx.exch.set_case_session(case_session); + Ok(true) } fn get_session_clone_data( @@ -322,8 +315,8 @@ impl Case { case_session: &CaseSession, ) -> Result<(), Error> { const MAX_TBS_SIZE: usize = 800; - let mut buf: [u8; MAX_TBS_SIZE] = [0; MAX_TBS_SIZE]; - let mut write_buf = WriteBuf::new(&mut buf, MAX_TBS_SIZE); + let mut buf = [0; MAX_TBS_SIZE]; + let mut write_buf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut write_buf); tw.start_struct(TagType::Anonymous)?; tw.str16(TagType::Context(1), initiator_noc)?; @@ -335,7 +328,7 @@ impl Case { tw.end_container()?; let key = KeyPair::new_from_public(initiator_noc_cert.get_pubkey())?; - key.verify_msg(write_buf.as_slice(), sign)?; + key.verify_msg(write_buf.into_slice(), sign)?; Ok(()) } @@ -372,12 +365,12 @@ impl Case { if key.len() < 48 { return Err(Error::NoSpace); } - let mut salt = Vec::::with_capacity(256); - salt.extend_from_slice(ipk); + let mut salt = heapless::Vec::::new(); + salt.extend_from_slice(ipk).unwrap(); let tt = tt.clone(); let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES]; tt.finish(&mut tt_hash)?; - salt.extend_from_slice(&tt_hash); + salt.extend_from_slice(&tt_hash).unwrap(); // println!("Session Key: salt: {:x?}, len: {}", salt, salt.len()); crypto::hkdf_sha256(salt.as_slice(), shared_secret, &SEKEYS_INFO, key) @@ -420,14 +413,14 @@ impl Case { if key.len() < 16 { return Err(Error::NoSpace); } - let mut salt = Vec::::with_capacity(256); - salt.extend_from_slice(ipk); + let mut salt = heapless::Vec::::new(); + salt.extend_from_slice(ipk).unwrap(); let tt = tt.clone(); let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES]; tt.finish(&mut tt_hash)?; - salt.extend_from_slice(&tt_hash); + salt.extend_from_slice(&tt_hash).unwrap(); // println!("Sigma3Key: salt: {:x?}, len: {}", salt, salt.len()); crypto::hkdf_sha256(salt.as_slice(), shared_secret, &S3K_INFO, key) @@ -447,16 +440,16 @@ impl Case { if key.len() < 16 { return Err(Error::NoSpace); } - let mut salt = Vec::::with_capacity(256); - salt.extend_from_slice(ipk); - salt.extend_from_slice(our_random); - salt.extend_from_slice(&case_session.our_pub_key); + let mut salt = heapless::Vec::::new(); + salt.extend_from_slice(ipk).unwrap(); + salt.extend_from_slice(our_random).unwrap(); + salt.extend_from_slice(&case_session.our_pub_key).unwrap(); let tt = case_session.tt_hash.clone(); let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES]; tt.finish(&mut tt_hash)?; - salt.extend_from_slice(&tt_hash); + salt.extend_from_slice(&tt_hash).unwrap(); // println!("Sigma2Key: salt: {:x?}, len: {}", salt, salt.len()); crypto::hkdf_sha256(salt.as_slice(), &case_session.shared_secret, &S2K_INFO, key) @@ -467,17 +460,15 @@ impl Case { } fn get_sigma2_encryption( - fabric: &RwLockReadGuardRef>, + fabric: &Fabric, + rand: Rand, our_random: &[u8], case_session: &mut CaseSession, signature: &[u8], out: &mut [u8], ) -> Result { let mut resumption_id: [u8; 16] = [0; 16]; - rand::thread_rng().fill_bytes(&mut resumption_id); - - // We are guaranteed this unwrap will work - let fabric = fabric.as_ref().as_ref().unwrap(); + rand(&mut resumption_id); let mut sigma2_key = [0_u8; crypto::SYMM_KEY_LEN_BYTES]; Case::get_sigma2_key( @@ -487,7 +478,7 @@ impl Case { &mut sigma2_key, )?; - let mut write_buf = WriteBuf::new(out, out.len()); + let mut write_buf = WriteBuf::new(out); let mut tw = TLVWriter::new(&mut write_buf); tw.start_struct(TagType::Anonymous)?; tw.str16_as(TagType::Context(1), |buf| fabric.noc.as_tlv(buf))?; @@ -517,20 +508,19 @@ impl Case { cipher_text, cipher_text.len() - TAG_LEN, )?; - Ok(write_buf.as_slice().len()) + Ok(write_buf.into_slice().len()) } fn get_sigma2_sign( - fabric: &RwLockReadGuardRef>, + fabric: &Fabric, our_pub_key: &[u8], peer_pub_key: &[u8], signature: &mut [u8], ) -> Result { // We are guaranteed this unwrap will work - let fabric = fabric.as_ref().as_ref().unwrap(); const MAX_TBS_SIZE: usize = 800; - let mut buf: [u8; MAX_TBS_SIZE] = [0; MAX_TBS_SIZE]; - let mut write_buf = WriteBuf::new(&mut buf, MAX_TBS_SIZE); + let mut buf = [0; MAX_TBS_SIZE]; + let mut write_buf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut write_buf); tw.start_struct(TagType::Anonymous)?; tw.str16_as(TagType::Context(1), |buf| fabric.noc.as_tlv(buf))?; @@ -541,7 +531,7 @@ impl Case { tw.str8(TagType::Context(4), peer_pub_key)?; tw.end_container()?; //println!("TBS is {:x?}", write_buf.as_borrow_slice()); - fabric.sign_msg(write_buf.as_slice(), signature) + fabric.sign_msg(write_buf.into_slice(), signature) } } diff --git a/matter/src/secure_channel/common.rs b/matter/src/secure_channel/common.rs index 511bb5cb..7049ba38 100644 --- a/matter/src/secure_channel/common.rs +++ b/matter/src/secure_channel/common.rs @@ -15,23 +15,14 @@ * limitations under the License. */ -use boxslab::Slab; -use log::info; use num_derive::FromPrimitive; -use crate::{ - error::Error, - transport::{ - exchange::Exchange, - packet::{Packet, PacketPool}, - session::SessionHandle, - }, -}; +use crate::{error::Error, transport::packet::Packet}; use super::status_report::{create_status_report, GeneralCode}; /* Interaction Model ID as per the Matter Spec */ -pub const PROTO_ID_SECURE_CHANNEL: usize = 0x00; +pub const PROTO_ID_SECURE_CHANNEL: u16 = 0x00; #[derive(FromPrimitive, Debug)] pub enum OpCode { @@ -88,14 +79,15 @@ pub fn create_sc_status_report( } pub fn create_mrp_standalone_ack(proto_tx: &mut Packet) { - proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL as u16); + proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); proto_tx.set_proto_opcode(OpCode::MRPStandAloneAck as u8); proto_tx.unset_reliable(); } -pub fn send_mrp_standalone_ack(exch: &mut Exchange, sess: &mut SessionHandle) -> Result<(), Error> { - info!("Sending standalone ACK"); - let mut ack_packet = Slab::::try_new(Packet::new_tx()?).ok_or(Error::NoMemory)?; - create_mrp_standalone_ack(&mut ack_packet); - exch.send(ack_packet, sess) -} +// TODO +// pub fn send_mrp_standalone_ack(exch: &mut Exchange, sess: &mut SessionHandle) -> Result<(), Error> { +// info!("Sending standalone ACK"); +// let mut ack_packet = Slab::::try_new(Packet::new_tx()?).ok_or(Error::NoMemory)?; +// create_mrp_standalone_ack(&mut ack_packet); +// exch.send(ack_packet, sess) +// } diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index 9f7d16b8..5ca18042 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -15,14 +15,11 @@ * limitations under the License. */ -use std::sync::Arc; +use core::cell::RefCell; use crate::{ - error::*, - fabric::FabricMgr, - secure_channel::common::*, - tlv, - transport::proto_demux::{self, ProtoCtx, ResponseRequired}, + error::*, fabric::FabricMgr, mdns::MdnsMgr, secure_channel::common::*, tlv, + transport::proto_ctx::ProtoCtx, utils::rand::Rand, }; use log::{error, info}; use num; @@ -32,48 +29,54 @@ use super::{case::Case, pake::PaseMgr}; /* Handle messages related to the Secure Channel */ -pub struct SecureChannel { - case: Case, - pase: PaseMgr, +pub struct SecureChannel<'a> { + case: Case<'a>, + pase: &'a RefCell, + mdns: &'a RefCell>, } -impl SecureChannel { - pub fn new(pase: PaseMgr, fabric_mgr: Arc) -> SecureChannel { +impl<'a> SecureChannel<'a> { + pub fn new( + pase: &'a RefCell, + fabric_mgr: &'a RefCell, + mdns: &'a RefCell>, + rand: Rand, + ) -> Self { SecureChannel { + case: Case::new(fabric_mgr, rand), pase, - case: Case::new(fabric_mgr), + mdns, } } -} -impl proto_demux::HandleProto for SecureChannel { - fn handle_proto_id(&mut self, ctx: &mut ProtoCtx) -> Result { + pub fn handle(&mut self, ctx: &mut ProtoCtx) -> Result { let proto_opcode: OpCode = num::FromPrimitive::from_u8(ctx.rx.get_proto_opcode()).ok_or(Error::Invalid)?; - ctx.tx.set_proto_id(PROTO_ID_SECURE_CHANNEL as u16); + ctx.tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); info!("Received Opcode: {:?}", proto_opcode); info!("Received Data:"); - tlv::print_tlv_list(ctx.rx.as_borrow_slice()); - let result = match proto_opcode { - OpCode::MRPStandAloneAck => Ok(ResponseRequired::No), - OpCode::PBKDFParamRequest => self.pase.pbkdfparamreq_handler(ctx), - OpCode::PASEPake1 => self.pase.pasepake1_handler(ctx), - OpCode::PASEPake3 => self.pase.pasepake3_handler(ctx), + tlv::print_tlv_list(ctx.rx.as_slice()); + let reply = match proto_opcode { + OpCode::MRPStandAloneAck => Ok(true), + OpCode::PBKDFParamRequest => self.pase.borrow_mut().pbkdfparamreq_handler(ctx), + OpCode::PASEPake1 => self.pase.borrow_mut().pasepake1_handler(ctx), + OpCode::PASEPake3 => self + .pase + .borrow_mut() + .pasepake3_handler(ctx, &mut self.mdns.borrow_mut()), OpCode::CASESigma1 => self.case.casesigma1_handler(ctx), OpCode::CASESigma3 => self.case.casesigma3_handler(ctx), _ => { error!("OpCode Not Handled: {:?}", proto_opcode); Err(Error::InvalidOpcode) } - }; - if result == Ok(ResponseRequired::Yes) { + }?; + + if reply { info!("Sending response"); - tlv::print_tlv_list(ctx.tx.as_borrow_slice()); + tlv::print_tlv_list(ctx.tx.as_mut_slice()); } - result - } - fn get_proto_id(&self) -> usize { - PROTO_ID_SECURE_CHANNEL + Ok(reply) } } diff --git a/matter/src/secure_channel/crypto.rs b/matter/src/secure_channel/crypto.rs index f9481ba2..45d83592 100644 --- a/matter/src/secure_channel/crypto.rs +++ b/matter/src/secure_channel/crypto.rs @@ -15,40 +15,15 @@ * limitations under the License. */ -use crate::error::Error; - -// This trait allows us to switch between crypto providers like OpenSSL and mbedTLS for Spake2 -// Currently this is only validate for a verifier(responder) - -// A verifier will typically do: -// Step 1: w0 and L -// set_w0_from_w0s -// set_L -// Step 2: get_pB -// Step 3: get_TT_as_verifier(pA) -// Step 4: Computation of cA and cB happens outside since it doesn't use either BigNum or EcPoint -pub trait CryptoSpake2 { - fn new() -> Result - where - Self: Sized; - - fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error>; - fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error>; - fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error>; - fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error>; - - #[allow(non_snake_case)] - fn set_L(&mut self, l: &[u8]) -> Result<(), Error>; - #[allow(non_snake_case)] - fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error>; - #[allow(non_snake_case)] - fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error>; - #[allow(non_snake_case)] - fn get_TT_as_verifier( - &mut self, - context: &[u8], - pA: &[u8], - pB: &[u8], - out: &mut [u8], - ) -> Result<(), Error>; -} +#[cfg(not(any( + feature = "crypto_openssl", + feature = "crypto_mbedtls", + feature = "crypto_esp_mbedtls" +)))] +pub use super::crypto_dummy::CryptoSpake2; +#[cfg(feature = "crypto_esp_mbedtls")] +pub use super::crypto_esp_mbedtls::CryptoSpake2; +#[cfg(feature = "crypto_mbedtls")] +pub use super::crypto_mbedtls::CryptoSpake2; +#[cfg(feature = "crypto_openssl")] +pub use super::crypto_openssl::CryptoSpake2; diff --git a/matter/src/secure_channel/crypto_dummy.rs b/matter/src/secure_channel/crypto_dummy.rs new file mode 100644 index 00000000..11ec8523 --- /dev/null +++ b/matter/src/secure_channel/crypto_dummy.rs @@ -0,0 +1,73 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use crate::error::Error; + +#[allow(non_snake_case)] + +pub struct CryptoSpake2 {} + +impl CryptoSpake2 { + #[allow(non_snake_case)] + pub fn new() -> Result { + Ok(Self {}) + } + + // Computes w0 from w0s respectively + pub fn set_w0_from_w0s(&mut self, _w0s: &[u8]) -> Result<(), Error> { + Err(Error::Invalid) + } + + pub fn set_w1_from_w1s(&mut self, _w1s: &[u8]) -> Result<(), Error> { + Err(Error::Invalid) + } + + pub fn set_w0(&mut self, _w0: &[u8]) -> Result<(), Error> { + Err(Error::Invalid) + } + + pub fn set_w1(&mut self, _w1: &[u8]) -> Result<(), Error> { + Err(Error::Invalid) + } + + #[allow(non_snake_case)] + pub fn set_L(&mut self, _l: &[u8]) -> Result<(), Error> { + Err(Error::Invalid) + } + + #[allow(non_snake_case)] + #[allow(dead_code)] + pub fn set_L_from_w1s(&mut self, _w1s: &[u8]) -> Result<(), Error> { + Err(Error::Invalid) + } + + #[allow(non_snake_case)] + pub fn get_pB(&mut self, _pB: &mut [u8]) -> Result<(), Error> { + Err(Error::Invalid) + } + + #[allow(non_snake_case)] + pub fn get_TT_as_verifier( + &mut self, + _context: &[u8], + _pA: &[u8], + _pB: &[u8], + _out: &mut [u8], + ) -> Result<(), Error> { + Err(Error::Invalid) + } +} diff --git a/matter/src/secure_channel/crypto_esp_mbedtls.rs b/matter/src/secure_channel/crypto_esp_mbedtls.rs index 632be2c6..316276ba 100644 --- a/matter/src/secure_channel/crypto_esp_mbedtls.rs +++ b/matter/src/secure_channel/crypto_esp_mbedtls.rs @@ -17,8 +17,6 @@ use crate::error::Error; -use super::crypto::CryptoSpake2; - const MATTER_M_BIN: [u8; 65] = [ 0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2, 0x99, 0x3b, 0x64, 0xe1, 0x6e, 0xf3, 0xdc, 0xab, 0x95, 0xaf, 0xd4, 0x97, 0x33, 0x3d, 0x8f, 0xa1, @@ -36,16 +34,16 @@ const MATTER_N_BIN: [u8; 65] = [ #[allow(non_snake_case)] -pub struct CryptoEspMbedTls {} +pub struct CryptoSpake2 {} -impl CryptoSpake2 for CryptoEspMbedTls { +impl CryptoSpake2 { #[allow(non_snake_case)] - fn new() -> Result { - Ok(CryptoEspMbedTls {}) + pub fn new() -> Result { + Ok(Self {}) } // Computes w0 from w0s respectively - fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { + pub fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { // From the Matter Spec, // w0 = w0s mod p // where p is the order of the curve @@ -53,7 +51,7 @@ impl CryptoSpake2 for CryptoEspMbedTls { Ok(()) } - fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { + pub fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { // From the Matter Spec, // w1 = w1s mod p // where p is the order of the curve @@ -61,17 +59,17 @@ impl CryptoSpake2 for CryptoEspMbedTls { Ok(()) } - fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { + pub fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { Ok(()) } - fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { + pub fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { Ok(()) } #[allow(non_snake_case)] #[allow(dead_code)] - fn set_L(&mut self, w1s: &[u8]) -> Result<(), Error> { + pub fn set_L(&mut self, w1s: &[u8]) -> Result<(), Error> { // From the Matter spec, // L = w1 * P // where P is the generator of the underlying elliptic curve @@ -79,7 +77,7 @@ impl CryptoSpake2 for CryptoEspMbedTls { } #[allow(non_snake_case)] - fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { + pub fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // for y // - select random y between 0 to p @@ -90,7 +88,7 @@ impl CryptoSpake2 for CryptoEspMbedTls { } #[allow(non_snake_case)] - fn get_TT_as_verifier( + pub fn get_TT_as_verifier( &mut self, context: &[u8], pA: &[u8], @@ -101,13 +99,10 @@ impl CryptoSpake2 for CryptoEspMbedTls { } } -impl CryptoEspMbedTls {} - #[cfg(test)] mod tests { - use super::CryptoEspMbedTls; - use crate::secure_channel::crypto::CryptoSpake2; + use super::CryptoSpake2; use crate::secure_channel::spake2p_test_vectors::test_vectors::*; use openssl::bn::BigNum; use openssl::ec::{EcPoint, PointConversionForm}; @@ -116,13 +111,12 @@ mod tests { #[allow(non_snake_case)] fn test_get_X() { for t in RFC_T { - let mut c = CryptoEspMbedTls::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let x = BigNum::from_slice(&t.x).unwrap(); c.set_w0(&t.w0).unwrap(); let P = c.group.generator(); - let r = - CryptoEspMbedTls::do_add_mul(P, &x, &c.M, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); + let r = CryptoSpake2::do_add_mul(P, &x, &c.M, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); assert_eq!( t.X, r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx) @@ -136,12 +130,11 @@ mod tests { #[allow(non_snake_case)] fn test_get_Y() { for t in RFC_T { - let mut c = CryptoEspMbedTls::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let y = BigNum::from_slice(&t.y).unwrap(); c.set_w0(&t.w0).unwrap(); let P = c.group.generator(); - let r = - CryptoEspMbedTls::do_add_mul(P, &y, &c.N, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); + let r = CryptoSpake2::do_add_mul(P, &y, &c.N, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); assert_eq!( t.Y, r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx) @@ -155,12 +148,12 @@ mod tests { #[allow(non_snake_case)] fn test_get_ZV_as_prover() { for t in RFC_T { - let mut c = CryptoEspMbedTls::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let x = BigNum::from_slice(&t.x).unwrap(); c.set_w0(&t.w0).unwrap(); c.set_w1(&t.w1).unwrap(); let Y = EcPoint::from_bytes(&c.group, &t.Y, &mut c.bn_ctx).unwrap(); - let (Z, V) = CryptoEspMbedTls::get_ZV_as_prover( + let (Z, V) = CryptoSpake2::get_ZV_as_prover( &c.w0, &c.w1, &mut c.N, @@ -191,12 +184,12 @@ mod tests { #[allow(non_snake_case)] fn test_get_ZV_as_verifier() { for t in RFC_T { - let mut c = CryptoEspMbedTls::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let y = BigNum::from_slice(&t.y).unwrap(); c.set_w0(&t.w0).unwrap(); let X = EcPoint::from_bytes(&c.group, &t.X, &mut c.bn_ctx).unwrap(); let L = EcPoint::from_bytes(&c.group, &t.L, &mut c.bn_ctx).unwrap(); - let (Z, V) = CryptoEspMbedTls::get_ZV_as_verifier( + let (Z, V) = CryptoSpake2::get_ZV_as_verifier( &c.w0, &L, &mut c.M, diff --git a/matter/src/secure_channel/crypto_mbedtls.rs b/matter/src/secure_channel/crypto_mbedtls.rs index 7ac4c5a5..27c9fc61 100644 --- a/matter/src/secure_channel/crypto_mbedtls.rs +++ b/matter/src/secure_channel/crypto_mbedtls.rs @@ -15,14 +15,11 @@ * limitations under the License. */ -use std::{ - ops::{Mul, Sub}, - sync::Arc, -}; +use alloc::sync::Arc; +use core::ops::{Mul, Sub}; use crate::error::Error; -use super::crypto::CryptoSpake2; use byteorder::{ByteOrder, LittleEndian}; use log::error; use mbedtls::{ @@ -33,6 +30,8 @@ use mbedtls::{ rng::{CtrDrbg, OsEntropy}, }; +extern crate alloc; + const MATTER_M_BIN: [u8; 65] = [ 0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2, 0x99, 0x3b, 0x64, 0xe1, 0x6e, 0xf3, 0xdc, 0xab, 0x95, 0xaf, 0xd4, 0x97, 0x33, 0x3d, 0x8f, 0xa1, @@ -50,7 +49,7 @@ const MATTER_N_BIN: [u8; 65] = [ #[allow(non_snake_case)] -pub struct CryptoMbedTLS { +pub struct CryptoSpake2 { group: EcGroup, order: Mpi, xy: Mpi, @@ -62,15 +61,15 @@ pub struct CryptoMbedTLS { pB: EcPoint, } -impl CryptoSpake2 for CryptoMbedTLS { +impl CryptoSpake2 { #[allow(non_snake_case)] - fn new() -> Result { + pub fn new() -> Result { let group = EcGroup::new(mbedtls::pk::EcGroupId::SecP256R1)?; let order = group.order()?; let M = EcPoint::from_binary(&group, &MATTER_M_BIN)?; let N = EcPoint::from_binary(&group, &MATTER_N_BIN)?; - Ok(CryptoMbedTLS { + Ok(Self { group, order, xy: Mpi::new(0)?, @@ -84,7 +83,7 @@ impl CryptoSpake2 for CryptoMbedTLS { } // Computes w0 from w0s respectively - fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { + pub fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { // From the Matter Spec, // w0 = w0s mod p // where p is the order of the curve @@ -94,7 +93,7 @@ impl CryptoSpake2 for CryptoMbedTLS { Ok(()) } - fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { + pub fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { // From the Matter Spec, // w1 = w1s mod p // where p is the order of the curve @@ -104,24 +103,25 @@ impl CryptoSpake2 for CryptoMbedTLS { Ok(()) } - fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { + pub fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { self.w0 = Mpi::from_binary(w0)?; Ok(()) } - fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { + pub fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { self.w1 = Mpi::from_binary(w1)?; Ok(()) } - fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { + #[allow(non_snake_case)] + pub fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { self.L = EcPoint::from_binary(&self.group, l)?; Ok(()) } #[allow(non_snake_case)] #[allow(dead_code)] - fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { + pub fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { // From the Matter spec, // L = w1 * P // where P is the generator of the underlying elliptic curve @@ -132,7 +132,7 @@ impl CryptoSpake2 for CryptoMbedTLS { } #[allow(non_snake_case)] - fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { + pub fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // for y // - select random y between 0 to p @@ -157,7 +157,7 @@ impl CryptoSpake2 for CryptoMbedTLS { } #[allow(non_snake_case)] - fn get_TT_as_verifier( + pub fn get_TT_as_verifier( &mut self, context: &[u8], pA: &[u8], @@ -166,21 +166,21 @@ impl CryptoSpake2 for CryptoMbedTLS { ) -> Result<(), Error> { let mut TT = Md::new(mbedtls::hash::Type::Sha256)?; // context - CryptoMbedTLS::add_to_tt(&mut TT, context)?; + Self::add_to_tt(&mut TT, context)?; // 2 empty identifiers - CryptoMbedTLS::add_to_tt(&mut TT, &[])?; - CryptoMbedTLS::add_to_tt(&mut TT, &[])?; + Self::add_to_tt(&mut TT, &[])?; + Self::add_to_tt(&mut TT, &[])?; // M - CryptoMbedTLS::add_to_tt(&mut TT, &MATTER_M_BIN)?; + Self::add_to_tt(&mut TT, &MATTER_M_BIN)?; // N - CryptoMbedTLS::add_to_tt(&mut TT, &MATTER_N_BIN)?; + Self::add_to_tt(&mut TT, &MATTER_N_BIN)?; // X = pA - CryptoMbedTLS::add_to_tt(&mut TT, pA)?; + Self::add_to_tt(&mut TT, pA)?; // Y = pB - CryptoMbedTLS::add_to_tt(&mut TT, pB)?; + Self::add_to_tt(&mut TT, pB)?; let X = EcPoint::from_binary(&self.group, pA)?; - let (Z, V) = CryptoMbedTLS::get_ZV_as_verifier( + let (Z, V) = Self::get_ZV_as_verifier( &self.w0, &self.L, &mut self.M, @@ -193,24 +193,22 @@ impl CryptoSpake2 for CryptoMbedTLS { // Z let tmp = Z.to_binary(&self.group, false)?; let tmp = tmp.as_slice(); - CryptoMbedTLS::add_to_tt(&mut TT, tmp)?; + Self::add_to_tt(&mut TT, tmp)?; // V let tmp = V.to_binary(&self.group, false)?; let tmp = tmp.as_slice(); - CryptoMbedTLS::add_to_tt(&mut TT, tmp)?; + Self::add_to_tt(&mut TT, tmp)?; // w0 let tmp = self.w0.to_binary()?; let tmp = tmp.as_slice(); - CryptoMbedTLS::add_to_tt(&mut TT, tmp)?; + Self::add_to_tt(&mut TT, tmp)?; TT.finish(out)?; Ok(()) } -} -impl CryptoMbedTLS { fn add_to_tt(tt: &mut Md, buf: &[u8]) -> Result<(), Error> { let mut len_buf: [u8; 8] = [0; 8]; LittleEndian::write_u64(&mut len_buf, buf.len() as u64); @@ -247,7 +245,7 @@ impl CryptoMbedTLS { let mut tmp = x.mul(w0)?; tmp = tmp.modulo(order)?; - let inverted_N = CryptoMbedTLS::invert(group, N)?; + let inverted_N = Self::invert(group, N)?; let Z = EcPoint::muladd(group, Y, x, &inverted_N, &tmp)?; // Cofactor for P256 is 1, so that is a No-Op @@ -283,7 +281,7 @@ impl CryptoMbedTLS { let mut tmp = y.mul(w0)?; tmp = tmp.modulo(order)?; - let inverted_M = CryptoMbedTLS::invert(group, M)?; + let inverted_M = Self::invert(group, M)?; let Z = EcPoint::muladd(group, X, y, &inverted_M, &tmp)?; // Cofactor for P256 is 1, so that is a No-Op @@ -302,8 +300,7 @@ impl CryptoMbedTLS { #[cfg(test)] mod tests { - use super::CryptoMbedTLS; - use crate::secure_channel::crypto::CryptoSpake2; + use super::CryptoSpake2; use crate::secure_channel::spake2p_test_vectors::test_vectors::*; use mbedtls::bignum::Mpi; use mbedtls::ecp::EcPoint; @@ -312,7 +309,7 @@ mod tests { #[allow(non_snake_case)] fn test_get_X() { for t in RFC_T { - let mut c = CryptoMbedTLS::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let x = Mpi::from_binary(&t.x).unwrap(); c.set_w0(&t.w0).unwrap(); let P = c.group.generator().unwrap(); @@ -326,7 +323,7 @@ mod tests { #[allow(non_snake_case)] fn test_get_Y() { for t in RFC_T { - let mut c = CryptoMbedTLS::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let y = Mpi::from_binary(&t.y).unwrap(); c.set_w0(&t.w0).unwrap(); let P = c.group.generator().unwrap(); @@ -339,12 +336,12 @@ mod tests { #[allow(non_snake_case)] fn test_get_ZV_as_prover() { for t in RFC_T { - let mut c = CryptoMbedTLS::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let x = Mpi::from_binary(&t.x).unwrap(); c.set_w0(&t.w0).unwrap(); c.set_w1(&t.w1).unwrap(); let Y = EcPoint::from_binary(&c.group, &t.Y).unwrap(); - let (Z, V) = CryptoMbedTLS::get_ZV_as_prover( + let (Z, V) = CryptoSpake2::get_ZV_as_prover( &c.w0, &c.w1, &mut c.N, @@ -364,12 +361,12 @@ mod tests { #[allow(non_snake_case)] fn test_get_ZV_as_verifier() { for t in RFC_T { - let mut c = CryptoMbedTLS::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let y = Mpi::from_binary(&t.y).unwrap(); c.set_w0(&t.w0).unwrap(); let X = EcPoint::from_binary(&c.group, &t.X).unwrap(); let L = EcPoint::from_binary(&c.group, &t.L).unwrap(); - let (Z, V) = CryptoMbedTLS::get_ZV_as_verifier( + let (Z, V) = CryptoSpake2::get_ZV_as_verifier( &c.w0, &L, &mut c.M, diff --git a/matter/src/secure_channel/crypto_openssl.rs b/matter/src/secure_channel/crypto_openssl.rs index 84d6793e..631cb6b9 100644 --- a/matter/src/secure_channel/crypto_openssl.rs +++ b/matter/src/secure_channel/crypto_openssl.rs @@ -17,7 +17,6 @@ use crate::error::Error; -use super::crypto::CryptoSpake2; use byteorder::{ByteOrder, LittleEndian}; use log::error; use openssl::{ @@ -44,7 +43,7 @@ const MATTER_N_BIN: [u8; 65] = [ #[allow(non_snake_case)] -pub struct CryptoOpenSSL { +pub struct CryptoSpake2 { group: EcGroup, bn_ctx: BigNumContext, // Stores the randomly generated x or y depending upon who we are @@ -58,9 +57,9 @@ pub struct CryptoOpenSSL { order: BigNum, } -impl CryptoSpake2 for CryptoOpenSSL { +impl CryptoSpake2 { #[allow(non_snake_case)] - fn new() -> Result { + pub fn new() -> Result { let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?; let mut bn_ctx = BigNumContext::new()?; let M = EcPoint::from_bytes(&group, &MATTER_M_BIN, &mut bn_ctx)?; @@ -70,7 +69,7 @@ impl CryptoSpake2 for CryptoOpenSSL { let mut order = BigNum::new()?; group.as_ref().order(&mut order, &mut bn_ctx)?; - Ok(CryptoOpenSSL { + Ok(Self { group, bn_ctx, xy: BigNum::new()?, @@ -85,7 +84,7 @@ impl CryptoSpake2 for CryptoOpenSSL { } // Computes w0 from w0s respectively - fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { + pub fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { // From the Matter Spec, // w0 = w0s mod p // where p is the order of the curve @@ -96,7 +95,7 @@ impl CryptoSpake2 for CryptoOpenSSL { Ok(()) } - fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { + pub fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { // From the Matter Spec, // w1 = w1s mod p // where p is the order of the curve @@ -107,24 +106,24 @@ impl CryptoSpake2 for CryptoOpenSSL { Ok(()) } - fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { + pub fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { self.w0 = BigNum::from_slice(w0)?; Ok(()) } - fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { + pub fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { self.w1 = BigNum::from_slice(w1)?; Ok(()) } - fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { + pub fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { self.L = EcPoint::from_bytes(&self.group, l, &mut self.bn_ctx)?; Ok(()) } #[allow(non_snake_case)] #[allow(dead_code)] - fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { + pub fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { // From the Matter spec, // L = w1 * P // where P is the generator of the underlying elliptic curve @@ -135,7 +134,7 @@ impl CryptoSpake2 for CryptoOpenSSL { } #[allow(non_snake_case)] - fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { + pub fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // for y // - select random y between 0 to p @@ -143,7 +142,7 @@ impl CryptoSpake2 for CryptoOpenSSL { // - pB = Y self.order.rand_range(&mut self.xy)?; let P = self.group.generator(); - self.pB = CryptoOpenSSL::do_add_mul( + self.pB = Self::do_add_mul( P, &self.xy, &self.N, @@ -166,7 +165,7 @@ impl CryptoSpake2 for CryptoOpenSSL { } #[allow(non_snake_case)] - fn get_TT_as_verifier( + pub fn get_TT_as_verifier( &mut self, context: &[u8], pA: &[u8], @@ -175,21 +174,21 @@ impl CryptoSpake2 for CryptoOpenSSL { ) -> Result<(), Error> { let mut TT = Hasher::new(MessageDigest::sha256())?; // context - CryptoOpenSSL::add_to_tt(&mut TT, context)?; + Self::add_to_tt(&mut TT, context)?; // 2 empty identifiers - CryptoOpenSSL::add_to_tt(&mut TT, &[])?; - CryptoOpenSSL::add_to_tt(&mut TT, &[])?; + Self::add_to_tt(&mut TT, &[])?; + Self::add_to_tt(&mut TT, &[])?; // M - CryptoOpenSSL::add_to_tt(&mut TT, &MATTER_M_BIN)?; + Self::add_to_tt(&mut TT, &MATTER_M_BIN)?; // N - CryptoOpenSSL::add_to_tt(&mut TT, &MATTER_N_BIN)?; + Self::add_to_tt(&mut TT, &MATTER_N_BIN)?; // X = pA - CryptoOpenSSL::add_to_tt(&mut TT, pA)?; + Self::add_to_tt(&mut TT, pA)?; // Y = pB - CryptoOpenSSL::add_to_tt(&mut TT, pB)?; + Self::add_to_tt(&mut TT, pB)?; let X = EcPoint::from_bytes(&self.group, pA, &mut self.bn_ctx)?; - let (Z, V) = CryptoOpenSSL::get_ZV_as_verifier( + let (Z, V) = Self::get_ZV_as_verifier( &self.w0, &self.L, &mut self.M, @@ -207,7 +206,7 @@ impl CryptoSpake2 for CryptoOpenSSL { &mut self.bn_ctx, )?; let tmp = tmp.as_slice(); - CryptoOpenSSL::add_to_tt(&mut TT, tmp)?; + Self::add_to_tt(&mut TT, tmp)?; // V let tmp = V.to_bytes( @@ -216,20 +215,18 @@ impl CryptoSpake2 for CryptoOpenSSL { &mut self.bn_ctx, )?; let tmp = tmp.as_slice(); - CryptoOpenSSL::add_to_tt(&mut TT, tmp)?; + Self::add_to_tt(&mut TT, tmp)?; // w0 let tmp = self.w0.to_vec(); let tmp = tmp.as_slice(); - CryptoOpenSSL::add_to_tt(&mut TT, tmp)?; + Self::add_to_tt(&mut TT, tmp)?; let h = TT.finish()?; TT_hash.copy_from_slice(h.as_ref()); Ok(()) } -} -impl CryptoOpenSSL { fn add_to_tt(tt: &mut Hasher, buf: &[u8]) -> Result<(), Error> { let mut len_buf: [u8; 8] = [0; 8]; LittleEndian::write_u64(&mut len_buf, buf.len() as u64); @@ -286,11 +283,11 @@ impl CryptoOpenSSL { let mut tmp = BigNum::new()?; tmp.mod_mul(x, w0, order, bn_ctx)?; N.invert(group, bn_ctx)?; - let Z = CryptoOpenSSL::do_add_mul(Y, x, N, &tmp, group, bn_ctx)?; + let Z = Self::do_add_mul(Y, x, N, &tmp, group, bn_ctx)?; // Cofactor for P256 is 1, so that is a No-Op tmp.mod_mul(w1, w0, order, bn_ctx)?; - let V = CryptoOpenSSL::do_add_mul(Y, w1, N, &tmp, group, bn_ctx)?; + let V = Self::do_add_mul(Y, w1, N, &tmp, group, bn_ctx)?; Ok((Z, V)) } @@ -321,7 +318,7 @@ impl CryptoOpenSSL { let mut tmp = BigNum::new()?; tmp.mod_mul(y, w0, order, bn_ctx)?; M.invert(group, bn_ctx)?; - let Z = CryptoOpenSSL::do_add_mul(X, y, M, &tmp, group, bn_ctx)?; + let Z = Self::do_add_mul(X, y, M, &tmp, group, bn_ctx)?; // Cofactor for P256 is 1, so that is a No-Op let mut V = EcPoint::new(group)?; @@ -333,7 +330,7 @@ impl CryptoOpenSSL { #[cfg(test)] mod tests { - use super::CryptoOpenSSL; + use super::CryptoSpake2; use crate::secure_channel::crypto::CryptoSpake2; use crate::secure_channel::spake2p_test_vectors::test_vectors::*; use openssl::bn::BigNum; @@ -343,12 +340,12 @@ mod tests { #[allow(non_snake_case)] fn test_get_X() { for t in RFC_T { - let mut c = CryptoOpenSSL::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let x = BigNum::from_slice(&t.x).unwrap(); c.set_w0(&t.w0).unwrap(); let P = c.group.generator(); - let r = CryptoOpenSSL::do_add_mul(P, &x, &c.M, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); + let r = CryptoSpake2::do_add_mul(P, &x, &c.M, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); assert_eq!( t.X, r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx) @@ -362,11 +359,11 @@ mod tests { #[allow(non_snake_case)] fn test_get_Y() { for t in RFC_T { - let mut c = CryptoOpenSSL::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let y = BigNum::from_slice(&t.y).unwrap(); c.set_w0(&t.w0).unwrap(); let P = c.group.generator(); - let r = CryptoOpenSSL::do_add_mul(P, &y, &c.N, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); + let r = CryptoSpake2::do_add_mul(P, &y, &c.N, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); assert_eq!( t.Y, r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx) @@ -380,12 +377,12 @@ mod tests { #[allow(non_snake_case)] fn test_get_ZV_as_prover() { for t in RFC_T { - let mut c = CryptoOpenSSL::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let x = BigNum::from_slice(&t.x).unwrap(); c.set_w0(&t.w0).unwrap(); c.set_w1(&t.w1).unwrap(); let Y = EcPoint::from_bytes(&c.group, &t.Y, &mut c.bn_ctx).unwrap(); - let (Z, V) = CryptoOpenSSL::get_ZV_as_prover( + let (Z, V) = CryptoSpake2::get_ZV_as_prover( &c.w0, &c.w1, &mut c.N, @@ -416,12 +413,12 @@ mod tests { #[allow(non_snake_case)] fn test_get_ZV_as_verifier() { for t in RFC_T { - let mut c = CryptoOpenSSL::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let y = BigNum::from_slice(&t.y).unwrap(); c.set_w0(&t.w0).unwrap(); let X = EcPoint::from_bytes(&c.group, &t.X, &mut c.bn_ctx).unwrap(); let L = EcPoint::from_bytes(&c.group, &t.L, &mut c.bn_ctx).unwrap(); - let (Z, V) = CryptoOpenSSL::get_ZV_as_verifier( + let (Z, V) = CryptoSpake2::get_ZV_as_verifier( &c.w0, &L, &mut c.M, diff --git a/matter/src/secure_channel/mod.rs b/matter/src/secure_channel/mod.rs index 9328b253..15417b3b 100644 --- a/matter/src/secure_channel/mod.rs +++ b/matter/src/secure_channel/mod.rs @@ -17,10 +17,17 @@ pub mod case; pub mod common; +#[cfg(not(any( + feature = "crypto_openssl", + feature = "crypto_mbedtls", + feature = "crypto_esp_mbedtls", + feature = "crypto_rustcrypto" +)))] +mod crypto_dummy; #[cfg(feature = "crypto_esp_mbedtls")] -pub mod crypto_esp_mbedtls; +mod crypto_esp_mbedtls; #[cfg(feature = "crypto_mbedtls")] -pub mod crypto_mbedtls; +mod crypto_mbedtls; #[cfg(feature = "crypto_openssl")] pub mod crypto_openssl; #[cfg(feature = "crypto_rustcrypto")] diff --git a/matter/src/secure_channel/pake.rs b/matter/src/secure_channel/pake.rs index 27aaeb0f..ce05fb65 100644 --- a/matter/src/secure_channel/pake.rs +++ b/matter/src/secure_channel/pake.rs @@ -15,10 +15,7 @@ * limitations under the License. */ -use std::{ - sync::{Arc, Mutex}, - time::{Duration, SystemTime}, -}; +use core::{fmt::Write, time::Duration}; use super::{ common::{create_sc_status_report, SCStatusCodes}, @@ -27,97 +24,115 @@ use super::{ use crate::{ crypto, error::Error, - mdns::{self, Mdns}, + mdns::{MdnsMgr, ServiceMode}, secure_channel::common::OpCode, - sys::SysMdnsService, tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV}, transport::{ exchange::ExchangeCtx, network::Address, - proto_demux::{ProtoCtx, ResponseRequired}, + proto_ctx::ProtoCtx, queue::{Msg, WorkQ}, session::{CloneData, SessionMode}, }, + utils::{epoch::Epoch, rand::Rand}, }; use log::{error, info}; -use rand::prelude::*; +#[allow(clippy::large_enum_variant)] enum PaseMgrState { - Enabled(PAKE, SysMdnsService), + Enabled(Pake, heapless::String<16>, u16), Disabled, } -pub struct PaseMgrInternal { +// Could this lock be avoided? +pub struct PaseMgr { state: PaseMgrState, + epoch: Epoch, + rand: Rand, } -#[derive(Clone)] -// Could this lock be avoided? -pub struct PaseMgr(Arc>); - impl PaseMgr { - pub fn new() -> Self { - Self(Arc::new(Mutex::new(PaseMgrInternal { + pub fn new(epoch: Epoch, rand: Rand) -> Self { + Self { state: PaseMgrState::Disabled, - }))) + epoch, + rand, + } } pub fn enable_pase_session( &mut self, verifier: VerifierData, discriminator: u16, + mdns: &mut MdnsMgr, ) -> Result<(), Error> { - let mut s = self.0.lock().unwrap(); - let name: u64 = rand::thread_rng().gen_range(0..0xFFFFFFFFFFFFFFFF); - let name = format!("{:016X}", name); - let mdns = Mdns::get()? - .publish_service(&name, mdns::ServiceMode::Commissionable(discriminator))?; - s.state = PaseMgrState::Enabled(PAKE::new(verifier), mdns); + let mut buf = [0; 8]; + (self.rand)(&mut buf); + let num = u64::from_be_bytes(buf); + + let mut mdns_service_name = heapless::String::<16>::new(); + write!(&mut mdns_service_name, "{:016X}", num).unwrap(); + + mdns.publish_service( + &mdns_service_name, + ServiceMode::Commissionable(discriminator), + )?; + self.state = PaseMgrState::Enabled( + Pake::new(verifier, self.epoch, self.rand), + mdns_service_name, + discriminator, + ); + Ok(()) } - pub fn disable_pase_session(&mut self) { - let mut s = self.0.lock().unwrap(); - s.state = PaseMgrState::Disabled; + pub fn disable_pase_session(&mut self, mdns: &mut MdnsMgr) -> Result<(), Error> { + if let PaseMgrState::Enabled(_, mdns_service_name, discriminator) = &self.state { + mdns.unpublish_service( + mdns_service_name, + ServiceMode::Commissionable(*discriminator), + )?; + } + + self.state = PaseMgrState::Disabled; + + Ok(()) } /// If the PASE Session is enabled, execute the closure, /// if not enabled, generate SC Status Report fn if_enabled(&mut self, ctx: &mut ProtoCtx, f: F) -> Result<(), Error> where - F: FnOnce(&mut PAKE, &mut ProtoCtx) -> Result<(), Error>, + F: FnOnce(&mut Pake, &mut ProtoCtx) -> Result<(), Error>, { - let mut s = self.0.lock().unwrap(); - if let PaseMgrState::Enabled(pake, _) = &mut s.state { + if let PaseMgrState::Enabled(pake, _, _) = &mut self.state { f(pake, ctx) } else { error!("PASE Not enabled"); - create_sc_status_report(&mut ctx.tx, SCStatusCodes::InvalidParameter, None) + create_sc_status_report(ctx.tx, SCStatusCodes::InvalidParameter, None) } } - pub fn pbkdfparamreq_handler(&mut self, ctx: &mut ProtoCtx) -> Result { + pub fn pbkdfparamreq_handler(&mut self, ctx: &mut ProtoCtx) -> Result { ctx.tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8); self.if_enabled(ctx, |pake, ctx| pake.handle_pbkdfparamrequest(ctx))?; - Ok(ResponseRequired::Yes) + Ok(true) } - pub fn pasepake1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { + pub fn pasepake1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { ctx.tx.set_proto_opcode(OpCode::PASEPake2 as u8); self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake1(ctx))?; - Ok(ResponseRequired::Yes) + Ok(true) } - pub fn pasepake3_handler(&mut self, ctx: &mut ProtoCtx) -> Result { + pub fn pasepake3_handler( + &mut self, + ctx: &mut ProtoCtx, + mdns: &mut MdnsMgr, + ) -> Result { self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?; - self.disable_pase_session(); - Ok(ResponseRequired::Yes) - } -} - -impl Default for PaseMgr { - fn default() -> Self { - Self::new() + self.disable_pase_session(mdns)?; + Ok(true) } } @@ -131,30 +146,31 @@ const PASE_DISCARD_TIMEOUT_SECS: Duration = Duration::from_secs(60); const SPAKE2_SESSION_KEYS_INFO: [u8; 11] = *b"SessionKeys"; struct SessionData { - start_time: SystemTime, + start_time: Duration, exch_id: u16, peer_addr: Address, - spake2p: Box, + spake2p: Spake2P, } impl SessionData { - fn is_sess_expired(&self) -> Result { - if SystemTime::now().duration_since(self.start_time)? > PASE_DISCARD_TIMEOUT_SECS { - Ok(true) - } else { - Ok(false) - } + fn is_sess_expired(&self, epoch: Epoch) -> Result { + Ok(epoch() - self.start_time > PASE_DISCARD_TIMEOUT_SECS) } } +#[allow(clippy::large_enum_variant)] enum PakeState { Idle, InProgress(SessionData), } impl PakeState { + const fn new() -> Self { + Self::Idle + } + fn take(&mut self) -> Result { - let new = std::mem::replace(self, PakeState::Idle); + let new = core::mem::replace(self, PakeState::Idle); if let PakeState::InProgress(s) = new { Ok(s) } else { @@ -163,7 +179,7 @@ impl PakeState { } fn is_idle(&self) -> bool { - std::mem::discriminant(self) == std::mem::discriminant(&PakeState::Idle) + core::mem::discriminant(self) == core::mem::discriminant(&PakeState::Idle) } fn take_sess_data(&mut self, exch_ctx: &ExchangeCtx) -> Result { @@ -175,9 +191,9 @@ impl PakeState { } } - fn make_in_progress(&mut self, spake2p: Box, exch_ctx: &ExchangeCtx) { + fn make_in_progress(&mut self, epoch: Epoch, spake2p: Spake2P, exch_ctx: &ExchangeCtx) { *self = PakeState::InProgress(SessionData { - start_time: SystemTime::now(), + start_time: epoch(), spake2p, exch_id: exch_ctx.exch.get_id(), peer_addr: exch_ctx.sess.get_peer_addr(), @@ -191,21 +207,25 @@ impl PakeState { impl Default for PakeState { fn default() -> Self { - Self::Idle + Self::new() } } -pub struct PAKE { - pub verifier: VerifierData, +struct Pake { + verifier: VerifierData, state: PakeState, + epoch: Epoch, + rand: Rand, } -impl PAKE { - pub fn new(verifier: VerifierData) -> Self { +impl Pake { + pub fn new(verifier: VerifierData, epoch: Epoch, rand: Rand) -> Self { // TODO: Can any PBKDF2 calculation be pre-computed here - PAKE { + Self { verifier, - state: Default::default(), + state: PakeState::new(), + epoch, + rand, } } @@ -213,14 +233,14 @@ impl PAKE { pub fn handle_pasepake3(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; - let cA = extract_pasepake_1_or_3_params(ctx.rx.as_borrow_slice())?; - let (status_code, Ke) = sd.spake2p.handle_cA(cA); + let cA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?; + let (status_code, ke) = sd.spake2p.handle_cA(cA); if status_code == SCStatusCodes::SessionEstablishmentSuccess { // Get the keys - let Ke = Ke.ok_or(Error::Invalid)?; + let ke = ke.ok_or(Error::Invalid)?; let mut session_keys: [u8; 48] = [0; 48]; - crypto::hkdf_sha256(&[], Ke, &SPAKE2_SESSION_KEYS_INFO, &mut session_keys) + crypto::hkdf_sha256(&[], ke, &SPAKE2_SESSION_KEYS_INFO, &mut session_keys) .map_err(|_x| Error::NoSpace)?; // Create a session @@ -245,7 +265,7 @@ impl PAKE { WorkQ::get()?.sync_send(Msg::NewSession(clone_data))?; } - create_sc_status_report(&mut ctx.tx, status_code, None)?; + create_sc_status_report(ctx.tx, status_code, None)?; ctx.exch_ctx.exch.close(); Ok(()) } @@ -254,7 +274,7 @@ impl PAKE { pub fn handle_pasepake1(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; - let pA = extract_pasepake_1_or_3_params(ctx.rx.as_borrow_slice())?; + let pA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?; let mut pB: [u8; 65] = [0; 65]; let mut cB: [u8; 32] = [0; 32]; sd.spake2p.start_verifier(&self.verifier)?; @@ -275,18 +295,18 @@ impl PAKE { pub fn handle_pbkdfparamrequest(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { if !self.state.is_idle() { let sd = self.state.take()?; - if sd.is_sess_expired()? { + if sd.is_sess_expired(self.epoch)? { info!("Previous session expired, clearing it"); self.state = PakeState::Idle; } else { info!("Previous session in-progress, denying new request"); // little-endian timeout (here we've hardcoded 500ms) - create_sc_status_report(&mut ctx.tx, SCStatusCodes::Busy, Some(&[0xf4, 0x01]))?; + create_sc_status_report(ctx.tx, SCStatusCodes::Busy, Some(&[0xf4, 0x01]))?; return Ok(()); } } - let root = tlv::get_root_node(ctx.rx.as_borrow_slice())?; + let root = tlv::get_root_node(ctx.rx.as_slice())?; let a = PBKDFParamReq::from_tlv(&root)?; if a.passcode_id != 0 { error!("Can't yet handle passcode_id != 0"); @@ -294,11 +314,11 @@ impl PAKE { } let mut our_random: [u8; 32] = [0; 32]; - rand::thread_rng().fill_bytes(&mut our_random); + (self.rand)(&mut our_random); let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id(); let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32; - let mut spake2p = Box::new(Spake2P::new()); + let mut spake2p = Spake2P::new(); spake2p.set_app_data(spake2p_data); // Generate response @@ -318,8 +338,9 @@ impl PAKE { } resp.to_tlv(&mut tw, TagType::Anonymous)?; - spake2p.set_context(ctx.rx.as_borrow_slice(), ctx.tx.as_borrow_slice())?; - self.state.make_in_progress(spake2p, &ctx.exch_ctx); + spake2p.set_context(ctx.rx.as_slice(), ctx.tx.as_mut_slice())?; + self.state + .make_in_progress(self.epoch, spake2p, &ctx.exch_ctx); Ok(()) } diff --git a/matter/src/secure_channel/spake2p.rs b/matter/src/secure_channel/spake2p.rs index 335d4651..ba948f51 100644 --- a/matter/src/secure_channel/spake2p.rs +++ b/matter/src/secure_channel/spake2p.rs @@ -18,10 +18,10 @@ use crate::{ crypto::{self, HmacSha256}, sys, + utils::rand::Rand, }; use byteorder::{ByteOrder, LittleEndian}; use log::error; -use rand::prelude::*; use subtle::ConstantTimeEq; use crate::{ @@ -29,18 +29,6 @@ use crate::{ error::Error, }; -#[cfg(feature = "crypto_openssl")] -use super::crypto_openssl::CryptoOpenSSL; - -#[cfg(feature = "crypto_mbedtls")] -use super::crypto_mbedtls::CryptoMbedTLS; - -#[cfg(feature = "crypto_esp_mbedtls")] -use super::crypto_esp_mbedtls::CryptoEspMbedTls; - -#[cfg(feature = "crypto_rustcrypto")] -use super::crypto_rustcrypto::CryptoRustCrypto; - use super::{common::SCStatusCodes, crypto::CryptoSpake2}; // This file handle Spake2+ specific instructions. In itself, this file is @@ -74,7 +62,7 @@ pub struct Spake2P { context: Option, Ke: [u8; 16], cA: [u8; 32], - crypto_spake2: Option>, + crypto_spake2: Option, app_data: u32, } @@ -87,24 +75,8 @@ const CRYPTO_PUBLIC_KEY_SIZE_BYTES: usize = (2 * CRYPTO_GROUP_SIZE_BYTES) + 1; const MAX_SALT_SIZE_BYTES: usize = 32; const VERIFIER_SIZE_BYTES: usize = CRYPTO_GROUP_SIZE_BYTES + CRYPTO_PUBLIC_KEY_SIZE_BYTES; -#[cfg(feature = "crypto_openssl")] -fn crypto_spake2_new() -> Result, Error> { - Ok(Box::new(CryptoOpenSSL::new()?)) -} - -#[cfg(feature = "crypto_mbedtls")] -fn crypto_spake2_new() -> Result, Error> { - Ok(Box::new(CryptoMbedTLS::new()?)) -} - -#[cfg(feature = "crypto_esp_mbedtls")] -fn crypto_spake2_new() -> Result, Error> { - Ok(Box::new(CryptoEspMbedTls::new()?)) -} - -#[cfg(feature = "crypto_rustcrypto")] -fn crypto_spake2_new() -> Result, Error> { - Ok(Box::new(CryptoRustCrypto::new()?)) +fn crypto_spake2_new() -> Result { + CryptoSpake2::new() } impl Default for Spake2P { @@ -129,13 +101,13 @@ pub enum VerifierOption { } impl VerifierData { - pub fn new_with_pw(pw: u32) -> Self { + pub fn new_with_pw(pw: u32, rand: Rand) -> Self { let mut s = Self { salt: [0; MAX_SALT_SIZE_BYTES], count: sys::SPAKE2_ITERATION_COUNT, data: VerifierOption::Password(pw), }; - rand::thread_rng().fill_bytes(&mut s.salt); + rand(&mut s.salt); s } @@ -158,7 +130,7 @@ impl VerifierData { } impl Spake2P { - pub fn new() -> Self { + pub const fn new() -> Self { Spake2P { mode: Spake2Mode::Unknown, context: None, @@ -198,7 +170,7 @@ impl Spake2P { match verifier.data { VerifierOption::Password(pw) => { // Derive w0 and L from the password - let mut w0w1s: [u8; 2 * CRYPTO_W_SIZE_BYTES] = [0; 2 * CRYPTO_W_SIZE_BYTES]; + let mut w0w1s: [u8; 2 * CRYPTO_W_SIZE_BYTES] = [0; (2 * CRYPTO_W_SIZE_BYTES)]; Spake2P::get_w0w1s(pw, verifier.count, &verifier.salt, &mut w0w1s); let w0s_len = w0w1s.len() / 2; @@ -317,7 +289,7 @@ mod tests { 0x4, 0xa1, 0xd2, 0xc6, 0x11, 0xf0, 0xbd, 0x36, 0x78, 0x67, 0x79, 0x7b, 0xfe, 0x82, 0x36, 0x0, ]; - let mut w0w1s: [u8; 2 * CRYPTO_W_SIZE_BYTES] = [0; 2 * CRYPTO_W_SIZE_BYTES]; + let mut w0w1s: [u8; 2 * CRYPTO_W_SIZE_BYTES] = [0; (2 * CRYPTO_W_SIZE_BYTES)]; Spake2P::get_w0w1s(123456, 2000, &salt, &mut w0w1s); assert_eq!( w0w1s, diff --git a/matter/src/secure_channel/status_report.rs b/matter/src/secure_channel/status_report.rs index 050cd5ba..477bcfae 100644 --- a/matter/src/secure_channel/status_report.rs +++ b/matter/src/secure_channel/status_report.rs @@ -46,7 +46,7 @@ pub fn create_status_report( proto_code: u16, proto_data: Option<&[u8]>, ) -> Result<(), Error> { - proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL as u16); + proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); proto_tx.set_proto_opcode(OpCode::StatusReport as u8); let wb = proto_tx.get_writebuf()?; wb.le_u16(general_code as u16)?; diff --git a/matter/src/sys/mod.rs b/matter/src/sys/mod.rs index e8e59cbd..9b5219ef 100644 --- a/matter/src/sys/mod.rs +++ b/matter/src/sys/mod.rs @@ -25,7 +25,8 @@ mod sys_linux; #[cfg(target_os = "linux")] pub use self::sys_linux::*; -#[cfg(any(target_os = "macos", target_os = "linux"))] -mod posix; -#[cfg(any(target_os = "macos", target_os = "linux"))] -pub use self::posix::*; +pub const SPAKE2_ITERATION_COUNT: u32 = 2000; + +// The Packet Pool that is allocated from. POSIX systems can use +// higher values unlike embedded systems +pub const MAX_PACKET_POOL_SIZE: usize = 25; diff --git a/matter/src/sys/posix.rs b/matter/src/sys/posix.rs deleted file mode 100644 index 2736e516..00000000 --- a/matter/src/sys/posix.rs +++ /dev/null @@ -1,96 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use std::{ - convert::TryInto, - fs::{remove_file, DirBuilder, File}, - io::{Read, Write}, - sync::{Arc, Mutex, Once}, -}; - -use crate::error::Error; - -pub const SPAKE2_ITERATION_COUNT: u32 = 2000; - -// The Packet Pool that is allocated from. POSIX systems can use -// higher values unlike embedded systems -pub const MAX_PACKET_POOL_SIZE: usize = 25; - -pub struct Psm {} - -static mut G_PSM: Option>> = None; -static INIT: Once = Once::new(); - -const PSM_DIR: &str = "/tmp/matter_psm"; - -macro_rules! psm_path { - ($key:ident) => { - format!("{}/{}", PSM_DIR, $key) - }; -} - -impl Psm { - fn new() -> Result { - let result = DirBuilder::new().create(PSM_DIR); - if let Err(e) = result { - if e.kind() != std::io::ErrorKind::AlreadyExists { - return Err(e.into()); - } - } - - Ok(Self {}) - } - - pub fn get() -> Result>, Error> { - unsafe { - INIT.call_once(|| { - G_PSM = Some(Arc::new(Mutex::new(Psm::new().unwrap()))); - }); - Ok(G_PSM.as_ref().ok_or(Error::Invalid)?.clone()) - } - } - - pub fn set_kv_slice(&self, key: &str, val: &[u8]) -> Result<(), Error> { - let mut f = File::create(psm_path!(key))?; - f.write_all(val)?; - Ok(()) - } - - pub fn get_kv_slice(&self, key: &str, val: &mut Vec) -> Result { - let mut f = File::open(psm_path!(key))?; - let len = f.read_to_end(val)?; - Ok(len) - } - - pub fn set_kv_u64(&self, key: &str, val: u64) -> Result<(), Error> { - let mut f = File::create(psm_path!(key))?; - f.write_all(&val.to_be_bytes())?; - Ok(()) - } - - pub fn get_kv_u64(&self, key: &str, val: &mut u64) -> Result<(), Error> { - let mut f = File::open(psm_path!(key))?; - let mut vec = Vec::new(); - let _ = f.read_to_end(&mut vec)?; - *val = u64::from_be_bytes(vec.as_slice().try_into()?); - Ok(()) - } - - pub fn rm(&self, key: &str) { - let _ = remove_file(psm_path!(key)); - } -} diff --git a/matter/src/tlv/parser.rs b/matter/src/tlv/parser.rs index a9b8b87e..f8b9716c 100644 --- a/matter/src/tlv/parser.rs +++ b/matter/src/tlv/parser.rs @@ -18,8 +18,8 @@ use crate::error::Error; use byteorder::{ByteOrder, LittleEndian}; +use core::fmt; use log::{error, info}; -use std::fmt; use super::{TagType, MAX_TAG_INDEX, TAG_MASK, TAG_SHIFT_BITS, TAG_SIZE_MAP, TYPE_MASK}; @@ -318,7 +318,7 @@ impl<'a> PartialEq for TLVElement<'a> { loop { let ours = our_iter.next(); let theirs = their.next(); - if std::mem::discriminant(&ours) != std::mem::discriminant(&theirs) { + if core::mem::discriminant(&ours) != core::mem::discriminant(&theirs) { // One of us reached end of list, but the other didn't, that's a mismatch return false; } @@ -341,8 +341,8 @@ impl<'a> PartialEq for TLVElement<'a> { // Only compare the discriminants in case of array/list/structures, // instead of actual element values. Those will be subsets within this same // list that will get validated anyway - if std::mem::discriminant(&ours.element_type) - != std::mem::discriminant(&theirs.element_type) + if core::mem::discriminant(&ours.element_type) + != core::mem::discriminant(&theirs.element_type) { return false; } @@ -438,6 +438,18 @@ impl<'a> TLVElement<'a> { } } + pub fn str(&self) -> Result<&'a str, Error> { + match self.element_type { + ElementType::Str8l(s) + | ElementType::Utf8l(s) + | ElementType::Str16l(s) + | ElementType::Utf16l(s) => { + Ok(core::str::from_utf8(s).map_err(|_| Error::InvalidData)?) + } + _ => Err(Error::TLVTypeMismatch), + } + } + pub fn bool(&self) -> Result { match self.element_type { ElementType::False => Ok(false), @@ -522,7 +534,7 @@ impl<'a> fmt::Display for TLVElement<'a> { | ElementType::Utf8l(a) | ElementType::Str16l(a) | ElementType::Utf16l(a) => { - if let Ok(s) = std::str::from_utf8(a) { + if let Ok(s) = core::str::from_utf8(a) { write!(f, "len[{}]\"{}\"", s.len(), s) } else { write!(f, "len[{}]{:x?}", a.len(), a) @@ -752,7 +764,7 @@ pub fn print_tlv_list(b: &[u8]) { match a.element_type { ElementType::Struct(_) => { if index < MAX_DEPTH { - println!("{}{}", space[index], a); + info!("{}{}", space[index], a); stack[index] = '}'; index += 1; } else { @@ -761,7 +773,7 @@ pub fn print_tlv_list(b: &[u8]) { } ElementType::Array(_) | ElementType::List(_) => { if index < MAX_DEPTH { - println!("{}{}", space[index], a); + info!("{}{}", space[index], a); stack[index] = ']'; index += 1; } else { @@ -771,19 +783,21 @@ pub fn print_tlv_list(b: &[u8]) { ElementType::EndCnt => { if index > 0 { index -= 1; - println!("{}{}", space[index], stack[index]); + info!("{}{}", space[index], stack[index]); } else { error!("Incorrect TLV List"); } } - _ => println!("{}{}", space[index], a), + _ => info!("{}{}", space[index], a), } } - println!("---------"); + info!("---------"); } #[cfg(test)] mod tests { + use log::info; + use super::{ get_root_node_list, get_root_node_struct, ElementType, Pointer, TLVElement, TLVList, TagType, @@ -1105,7 +1119,7 @@ mod tests { .unwrap() .enter() .unwrap(); - println!("Command list iterator: {:?}", cmd_list_iter); + info!("Command list iterator: {:?}", cmd_list_iter); // This is an array of CommandDataIB, but we'll only use the first element let cmd_data_ib = cmd_list_iter.next().unwrap(); @@ -1203,8 +1217,8 @@ mod tests { Some(a) => { assert_eq!(a.tag_type, verify_matrix[index].0); assert_eq!( - std::mem::discriminant(&a.element_type), - std::mem::discriminant(&verify_matrix[index].1) + core::mem::discriminant(&a.element_type), + core::mem::discriminant(&verify_matrix[index].1) ); } } diff --git a/matter/src/tlv/traits.rs b/matter/src/tlv/traits.rs index 2d3ceddd..c7b5e359 100644 --- a/matter/src/tlv/traits.rs +++ b/matter/src/tlv/traits.rs @@ -17,9 +17,13 @@ use super::{ElementType, TLVContainerIterator, TLVElement, TLVWriter, TagType}; use crate::error::Error; +use alloc::borrow::ToOwned; +use alloc::{string::String, vec::Vec}; +use core::fmt::Debug; use core::slice::Iter; use log::error; -use std::fmt::Debug; + +extern crate alloc; pub trait FromTLV<'a> { fn from_tlv(t: &TLVElement<'a>) -> Result @@ -76,6 +80,15 @@ pub trait ToTLV { fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error>; } +impl ToTLV for &T +where + T: ToTLV, +{ + fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { + (**self).to_tlv(tw, tag) + } +} + macro_rules! totlv_for { ($($t:ident)*) => { $( @@ -116,10 +129,14 @@ totlv_for!(i8 u8 u16 u32 u64 bool); pub struct UtfStr<'a>(pub &'a [u8]); impl<'a> UtfStr<'a> { - pub fn new(str: &'a [u8]) -> Self { + pub const fn new(str: &'a [u8]) -> Self { Self(str) } + pub fn as_str(&self) -> Result<&str, Error> { + core::str::from_utf8(self.0).map_err(|_| Error::Invalid) + } + pub fn to_string(self) -> Result { String::from_utf8(self.0.to_vec()).map_err(|_| Error::Invalid) } @@ -396,7 +413,7 @@ impl<'a, T> FromTLV<'a> for TLVArray<'a, T> { } impl<'a, T: Debug + ToTLV + FromTLV<'a> + Copy> Debug for TLVArray<'a, T> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { for i in self.iter() { writeln!(f, "{:?}", i)?; } @@ -442,9 +459,8 @@ mod tests { } #[test] fn test_derive_totlv() { - let mut buf: [u8; 20] = [0; 20]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut buf = [0; 20]; + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); let abc = TestDerive { @@ -525,9 +541,8 @@ mod tests { #[test] fn test_derive_totlv_fab_scoped() { - let mut buf: [u8; 20] = [0; 20]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut buf = [0; 20]; + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); let abc = TestDeriveFabScoped { a: 20, fab_idx: 3 }; @@ -557,9 +572,8 @@ mod tests { enum_val = TestDeriveEnum::ValueB(10); // Test ToTLV - let mut buf: [u8; 20] = [0; 20]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut buf = [0; 20]; + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); enum_val.to_tlv(&mut tw, TagType::Anonymous).unwrap(); diff --git a/matter/src/tlv/writer.rs b/matter/src/tlv/writer.rs index cdf914a2..1db84210 100644 --- a/matter/src/tlv/writer.rs +++ b/matter/src/tlv/writer.rs @@ -50,11 +50,11 @@ enum WriteElementType { } pub struct TLVWriter<'a, 'b> { - buf: &'b mut WriteBuf<'a>, + buf: &'a mut WriteBuf<'b>, } impl<'a, 'b> TLVWriter<'a, 'b> { - pub fn new(buf: &'b mut WriteBuf<'a>) -> Self { + pub fn new(buf: &'a mut WriteBuf<'b>) -> Self { TLVWriter { buf } } @@ -265,7 +265,7 @@ impl<'a, 'b> TLVWriter<'a, 'b> { self.buf.rewind_tail_to(anchor); } - pub fn get_buf<'c>(&'c mut self) -> &'c mut WriteBuf<'a> { + pub fn get_buf(&mut self) -> &mut WriteBuf<'b> { self.buf } } @@ -277,9 +277,8 @@ mod tests { #[test] fn test_write_success() { - let mut buf: [u8; 20] = [0; 20]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut buf = [0; 20]; + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); tw.start_struct(TagType::Anonymous).unwrap(); @@ -299,9 +298,8 @@ mod tests { #[test] fn test_write_overflow() { - let mut buf: [u8; 6] = [0; 6]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut buf = [0; 6]; + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); tw.u8(TagType::Anonymous, 12).unwrap(); @@ -317,9 +315,8 @@ mod tests { #[test] fn test_put_str8() { - let mut buf: [u8; 20] = [0; 20]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut buf = [0; 20]; + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); tw.u8(TagType::Context(1), 13).unwrap(); @@ -334,9 +331,8 @@ mod tests { #[test] fn test_put_str16_as() { - let mut buf: [u8; 20] = [0; 20]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut buf = [0; 20]; + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); tw.u8(TagType::Context(1), 13).unwrap(); diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index ae5711d4..e668a8d5 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -15,43 +15,46 @@ * limitations under the License. */ -use boxslab::{BoxSlab, Slab}; use colored::*; +use core::any::Any; +use core::fmt; +use core::time::Duration; use log::{error, info, trace}; -use std::any::Any; -use std::fmt; -use std::time::SystemTime; use crate::error::Error; use crate::secure_channel; +use crate::secure_channel::case::CaseSession; +use crate::utils::epoch::Epoch; +use crate::utils::rand::Rand; use heapless::LinearMap; -use super::packet::PacketPool; use super::session::CloneData; use super::{mrp::ReliableMessage, packet::Packet, session::SessionHandle, session::SessionMgr}; pub struct ExchangeCtx<'a> { pub exch: &'a mut Exchange, pub sess: SessionHandle<'a>, + pub epoch: Epoch, } -#[derive(Debug, PartialEq, Eq, Copy, Clone)] +impl<'a> ExchangeCtx<'a> { + pub fn send(&mut self, tx: &mut Packet) -> Result<(), Error> { + self.exch.send(tx, &mut self.sess) + } +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone, Default)] pub enum Role { + #[default] Initiator = 0, Responder = 1, } -impl Default for Role { - fn default() -> Self { - Role::Initiator - } -} - -/// State of the exchange -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Default)] enum State { /// The exchange is open and active + #[default] Open, /// The exchange is closed, but keys are active since retransmissions/acks may be pending Close, @@ -59,28 +62,17 @@ enum State { Terminate, } -impl Default for State { - fn default() -> Self { - State::Open - } -} - // Instead of just doing an Option<>, we create some special handling // where the commonly used higher layer data store does't have to do a Box -#[derive(Debug)] +#[derive(Default)] pub enum DataOption { - Boxed(Box), - Time(SystemTime), + CaseSession(CaseSession), + Time(Duration), + #[default] None, } -impl Default for DataOption { - fn default() -> Self { - DataOption::None - } -} - -#[derive(Debug, Default)] +#[derive(Default)] pub struct Exchange { id: u16, sess_idx: usize, @@ -136,48 +128,48 @@ impl Exchange { matches!(self.data, DataOption::None) } - pub fn set_data_boxed(&mut self, data: Box) { - self.data = DataOption::Boxed(data); + pub fn set_case_session(&mut self, session: CaseSession) { + self.data = DataOption::CaseSession(session); } - pub fn clear_data_boxed(&mut self) { + pub fn clear_data(&mut self) { self.data = DataOption::None; } - pub fn get_data_boxed(&mut self) -> Option<&mut T> { - if let DataOption::Boxed(a) = &mut self.data { - a.downcast_mut::() + pub fn get_case_session(&mut self) -> Option<&mut CaseSession> { + if let DataOption::CaseSession(session) = &mut self.data { + Some(session) } else { None } } - pub fn take_data_boxed(&mut self) -> Option> { - let old = std::mem::replace(&mut self.data, DataOption::None); - if let DataOption::Boxed(d) = old { - d.downcast::().ok() + pub fn take_case_session(&mut self) -> Option { + let old = core::mem::replace(&mut self.data, DataOption::None); + if let DataOption::CaseSession(session) = old { + Some(session) } else { self.data = old; None } } - pub fn set_data_time(&mut self, expiry_ts: Option) { + pub fn set_data_time(&mut self, expiry_ts: Option) { if let Some(t) = expiry_ts { self.data = DataOption::Time(t); } } - pub fn get_data_time(&self) -> Option { + pub fn get_data_time(&self) -> Option { match self.data { DataOption::Time(t) => Some(t), _ => None, } } - pub fn send( + pub(crate) fn send( &mut self, - mut proto_tx: BoxSlab, + tx: &mut Packet, session: &mut SessionHandle, ) -> Result<(), Error> { if self.state == State::Terminate { @@ -185,22 +177,22 @@ impl Exchange { return Ok(()); } - trace!("payload: {:x?}", proto_tx.as_borrow_slice()); + trace!("payload: {:x?}", tx.as_mut_slice()); info!( "{} with proto id: {} opcode: {}", "Sending".blue(), - proto_tx.get_proto_id(), - proto_tx.get_proto_opcode(), + tx.get_proto_id(), + tx.get_proto_opcode(), ); - proto_tx.proto.exch_id = self.id; + tx.proto.exch_id = self.id; if self.role == Role::Initiator { - proto_tx.proto.set_initiator(); + tx.proto.set_initiator(); } - session.pre_send(&mut proto_tx)?; - self.mrp.pre_send(&mut proto_tx)?; - session.send(proto_tx) + session.pre_send(tx)?; + self.mrp.pre_send(tx)?; + session.send(tx) } } @@ -208,8 +200,8 @@ impl fmt::Display for Exchange { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "exch_id: {:?}, sess_index: {}, role: {:?}, data: {:?}, mrp: {:?}, state: {:?}", - self.id, self.sess_idx, self.role, self.data, self.mrp, self.state + "exch_id: {:?}, sess_index: {}, role: {:?}, mrp: {:?}, state: {:?}", + self.id, self.sess_idx, self.role, self.mrp, self.state ) } } @@ -232,20 +224,21 @@ pub fn get_complementary_role(is_initiator: bool) -> Role { const MAX_EXCHANGES: usize = 8; -#[derive(Default)] pub struct ExchangeMgr { // keys: exch-id exchanges: LinearMap, sess_mgr: SessionMgr, + epoch: Epoch, } pub const MAX_MRP_ENTRIES: usize = 4; impl ExchangeMgr { - pub fn new(sess_mgr: SessionMgr) -> Self { + pub fn new(epoch: Epoch, rand: Rand) -> Self { Self { - sess_mgr, - exchanges: Default::default(), + sess_mgr: SessionMgr::new(epoch, rand), + exchanges: LinearMap::new(), + epoch, } } @@ -300,45 +293,33 @@ impl ExchangeMgr { } /// The Exchange Mgr receive is like a big processing function - pub fn recv(&mut self) -> Result, ExchangeCtx)>, Error> { + pub fn recv(&mut self, rx: &mut Packet) -> Result, Error> { // Get the session - let (mut proto_rx, index) = self.sess_mgr.recv()?; - - let index = if let Some(s) = index { - s - } else { - // The sessions were full, evict one session, and re-perform post-recv - let evict_index = self.sess_mgr.get_lru(); - self.evict_session(evict_index)?; - info!("Reattempting session creation"); - self.sess_mgr.post_recv(&proto_rx)?.ok_or(Error::Invalid)? - }; + let index = self.sess_mgr.post_recv(rx)?; let mut session = self.sess_mgr.get_session_handle(index); // Decrypt the message - session.recv(&mut proto_rx)?; + session.recv(self.epoch, rx)?; // Get the exchange let exch = ExchangeMgr::_get( &mut self.exchanges, index, - proto_rx.proto.exch_id, - get_complementary_role(proto_rx.proto.is_initiator()), + rx.proto.exch_id, + get_complementary_role(rx.proto.is_initiator()), // We create a new exchange, only if the peer is the initiator - proto_rx.proto.is_initiator(), + rx.proto.is_initiator(), )?; // Message Reliability Protocol - exch.mrp.recv(&proto_rx)?; + exch.mrp.recv(rx, self.epoch)?; if exch.is_state_open() { - Ok(Some(( - proto_rx, - ExchangeCtx { - exch, - sess: session, - }, - ))) + Ok(Some(ExchangeCtx { + exch, + sess: session, + epoch: self.epoch, + })) } else { // Instead of an error, we send None here, because it is likely that // we just processed an acknowledgement that cleared the exchange @@ -346,11 +327,11 @@ impl ExchangeMgr { } } - pub fn send(&mut self, exch_id: u16, proto_tx: BoxSlab) -> Result<(), Error> { + pub fn send(&mut self, exch_id: u16, tx: &mut Packet) -> Result<(), Error> { let exchange = ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id).ok_or(Error::NoExchange)?; let mut session = self.sess_mgr.get_session_handle(exchange.sess_idx); - exchange.send(proto_tx, &mut session) + exchange.send(tx, &mut session) } pub fn purge(&mut self) { @@ -366,70 +347,66 @@ impl ExchangeMgr { } } - pub fn pending_acks(&mut self, expired_entries: &mut LinearMap) { - for (exch_id, exchange) in self.exchanges.iter() { - if exchange.mrp.is_ack_ready() { - expired_entries.insert(*exch_id, ()).unwrap(); + pub fn pending_ack(&mut self) -> Option { + self.exchanges + .iter() + .find(|(_, exchange)| exchange.mrp.is_ack_ready(self.epoch)) + .map(|(exch_id, _)| *exch_id) + } + + pub fn evict_session(&mut self, tx: &mut Packet) -> Result { + if let Some(index) = self.sess_mgr.get_session_for_eviction() { + info!("Sessions full, vacating session with index: {}", index); + // If we enter here, we have an LRU session that needs to be reclaimed + // As per the spec, we need to send a CLOSE here + + let mut session = self.sess_mgr.get_session_handle(index); + secure_channel::common::create_sc_status_report( + tx, + secure_channel::common::SCStatusCodes::CloseSession, + None, + )?; + + if let Some((_, exchange)) = + self.exchanges.iter_mut().find(|(_, e)| e.sess_idx == index) + { + // Send Close_session on this exchange, and then close the session + // Should this be done for all exchanges? + error!("Sending Close Session"); + exchange.send(tx, &mut session)?; + // TODO: This wouldn't actually send it out, because 'transport' isn't owned yet. } - } - } - - pub fn evict_session(&mut self, index: usize) -> Result<(), Error> { - info!("Sessions full, vacating session with index: {}", index); - // If we enter here, we have an LRU session that needs to be reclaimed - // As per the spec, we need to send a CLOSE here - - let mut session = self.sess_mgr.get_session_handle(index); - let mut tx = Slab::::try_new(Packet::new_tx()?).ok_or(Error::NoSpace)?; - secure_channel::common::create_sc_status_report( - &mut tx, - secure_channel::common::SCStatusCodes::CloseSession, - None, - )?; - if let Some((_, exchange)) = self.exchanges.iter_mut().find(|(_, e)| e.sess_idx == index) { - // Send Close_session on this exchange, and then close the session - // Should this be done for all exchanges? - error!("Sending Close Session"); - exchange.send(tx, &mut session)?; - // TODO: This wouldn't actually send it out, because 'transport' isn't owned yet. - } + let remove_exchanges: heapless::Vec = self + .exchanges + .iter() + .filter_map(|(eid, e)| { + if e.sess_idx == index { + Some(*eid) + } else { + None + } + }) + .collect(); + info!( + "Terminating the following exchanges: {:?}", + remove_exchanges + ); + for exch_id in remove_exchanges { + // Remove from exchange list + self.exchanges.remove(&exch_id); + } + self.sess_mgr.remove(index); - let remove_exchanges: Vec = self - .exchanges - .iter() - .filter_map(|(eid, e)| { - if e.sess_idx == index { - Some(*eid) - } else { - None - } - }) - .collect(); - info!( - "Terminating the following exchanges: {:?}", - remove_exchanges - ); - for exch_id in remove_exchanges { - // Remove from exchange list - self.exchanges.remove(&exch_id); + Ok(true) + } else { + Ok(false) } - self.sess_mgr.remove(index); - Ok(()) } pub fn add_session(&mut self, clone_data: &CloneData) -> Result { - let sess_idx = match self.sess_mgr.clone_session(clone_data) { - Ok(idx) => idx, - Err(Error::NoSpace) => { - let evict_index = self.sess_mgr.get_lru(); - self.evict_session(evict_index)?; - self.sess_mgr.clone_session(clone_data)? - } - Err(e) => { - return Err(e); - } - }; + let sess_idx = self.sess_mgr.clone_session(clone_data)?; + Ok(self.sess_mgr.get_session_handle(sess_idx)) } } @@ -449,12 +426,16 @@ impl fmt::Display for ExchangeMgr { #[cfg(test)] #[allow(clippy::bool_assert_comparison)] mod tests { - use crate::{ error::Error, transport::{ - network::{Address, NetworkInterface}, - session::{CloneData, SessionMgr, SessionMode, MAX_SESSIONS}, + network::Address, + packet::Packet, + session::{CloneData, SessionMode, MAX_SESSIONS}, + }, + utils::{ + epoch::{dummy_epoch, sys_epoch}, + rand::dummy_rand, }, }; @@ -462,8 +443,7 @@ mod tests { #[test] fn test_purge() { - let sess_mgr = SessionMgr::new(); - let mut mgr = ExchangeMgr::new(sess_mgr); + let mut mgr = ExchangeMgr::new(dummy_epoch, dummy_rand); let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, true).unwrap(); let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 3, Role::Responder, true).unwrap(); @@ -519,33 +499,13 @@ mod tests { } } - pub struct DummyNetwork; - impl DummyNetwork { - pub fn new() -> Self { - Self {} - } - } - - impl NetworkInterface for DummyNetwork { - fn recv(&self, _in_buf: &mut [u8]) -> Result<(usize, Address), Error> { - Ok((0, Address::default())) - } - - fn send(&self, _out_buf: &[u8], _addr: Address) -> Result { - Ok(0) - } - } - #[test] /// We purposefuly overflow the sessions /// and when the overflow happens, we confirm that /// - The sessions are evicted in LRU /// - The exchanges associated with those sessions are evicted too fn test_sess_evict() { - let mut sess_mgr = SessionMgr::new(); - let transport = Box::new(DummyNetwork::new()); - sess_mgr.add_network_interface(transport).unwrap(); - let mut mgr = ExchangeMgr::new(sess_mgr); + let mut mgr = ExchangeMgr::new(sys_epoch, dummy_rand); // TODO fill_sessions(&mut mgr, MAX_SESSIONS + 1); // Sessions are now full from local session id 1 to 16 @@ -568,6 +528,14 @@ mod tests { for i in 1..(MAX_SESSIONS + 1) { // Now purposefully overflow the sessions by adding another session + let result = mgr.add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id)); + assert!(matches!(result, Err(Error::NoSpace))); + + let mut buf = [0; 1500]; + let tx = &mut Packet::new_tx(&mut buf); + let evicted = mgr.evict_session(tx).unwrap(); + assert!(evicted); + let session = mgr .add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id)) .unwrap(); diff --git a/matter/src/transport/mgr.rs b/matter/src/transport/mgr.rs index 76c506c4..349cfdee 100644 --- a/matter/src/transport/mgr.rs +++ b/matter/src/transport/mgr.rs @@ -15,161 +15,210 @@ * limitations under the License. */ -use async_channel::Receiver; -use boxslab::{BoxSlab, Slab}; -use heapless::LinearMap; -use log::{debug, error, info}; +use core::borrow::Borrow; +use core::cell::RefCell; + +use log::info; use crate::error::*; +use crate::fabric::FabricMgr; +use crate::mdns::MdnsMgr; +use crate::secure_channel::pake::PaseMgr; +use crate::secure_channel::common::PROTO_ID_SECURE_CHANNEL; +use crate::secure_channel::core::SecureChannel; use crate::transport::mrp::ReliableMessage; -use crate::transport::packet::PacketPool; -use crate::transport::{exchange, packet::Packet, proto_demux, queue, session, udp}; - -use super::proto_demux::ProtoCtx; -use super::queue::Msg; +use crate::transport::{exchange, packet::Packet}; +use crate::utils::epoch::Epoch; +use crate::utils::rand::Rand; + +use super::proto_ctx::ProtoCtx; + +#[derive(Copy, Clone, Eq, PartialEq)] +enum RecvState { + New, + OpenExchange, + EvictSession, + Ack, +} -pub struct Mgr { - exch_mgr: exchange::ExchangeMgr, - proto_demux: proto_demux::ProtoDemux, - rx_q: Receiver, +pub enum RecvAction<'r, 'p> { + Send(&'r [u8]), + Interact(ProtoCtx<'r, 'p>), } -impl Mgr { - pub fn new() -> Result { - let mut sess_mgr = session::SessionMgr::new(); - let udp_transport = Box::new(udp::UdpListener::new()?); - sess_mgr.add_network_interface(udp_transport)?; - Ok(Mgr { - proto_demux: proto_demux::ProtoDemux::new(), - exch_mgr: exchange::ExchangeMgr::new(sess_mgr), - rx_q: queue::WorkQ::init()?, - }) - } +pub struct RecvCompletion<'r, 'a, 'p> { + mgr: &'r mut TransportMgr<'a>, + rx: &'r mut Packet<'p>, + tx: &'r mut Packet<'p>, + state: RecvState, +} - // Allows registration of different protocols with the Transport/Protocol Demux - pub fn register_protocol( - &mut self, - proto_id_handle: Box, - ) -> Result<(), Error> { - self.proto_demux.register(proto_id_handle) - } +impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { + pub fn next_action(&mut self) -> Result>, Error> { + loop { + // Polonius will remove the need for unsafe one day + let this = unsafe { (self as *mut RecvCompletion).as_mut().unwrap() }; - fn send_to_exchange( - &mut self, - exch_id: u16, - proto_tx: BoxSlab, - ) -> Result<(), Error> { - self.exch_mgr.send(exch_id, proto_tx) + if let Some(action) = this.maybe_next_action()? { + return Ok(action); + } + } } - fn handle_rxtx(&mut self) -> Result<(), Error> { - let result = self.exch_mgr.recv().map_err(|e| { - error!("Error in recv: {:?}", e); - e - })?; + fn maybe_next_action(&mut self) -> Result>>, Error> { + self.mgr.exch_mgr.purge(); - if result.is_none() { - // Nothing to process, return quietly - return Ok(()); - } - // result contains something worth processing, we can safely unwrap - // as we already checked for none above - let (rx, exch_ctx) = result.unwrap(); - - debug!("Exchange is {:?}", exch_ctx.exch); - let tx = Self::new_tx()?; - - let mut proto_ctx = ProtoCtx::new(exch_ctx, rx, tx); - // Proto Dispatch - match self.proto_demux.handle(&mut proto_ctx) { - Ok(r) => { - if let proto_demux::ResponseRequired::No = r { - // We need to send the Ack if reliability is enabled, in this case - return Ok(()); + match self.state { + RecvState::New => { + self.mgr.exch_mgr.get_sess_mgr().decode(self.rx)?; + self.state = RecvState::OpenExchange; + Ok(None) + } + RecvState::OpenExchange => match self.mgr.exch_mgr.recv(self.rx) { + Ok(Some(exch_ctx)) => { + if self.rx.get_proto_id() == PROTO_ID_SECURE_CHANNEL { + let mut proto_ctx = ProtoCtx::new(exch_ctx, self.rx, self.tx); + + if self.mgr.secure_channel.handle(&mut proto_ctx)? { + proto_ctx.send()?; + + self.state = RecvState::Ack; + Ok(Some(Some(RecvAction::Send(self.tx.as_slice())))) + } else { + self.state = RecvState::Ack; + Ok(None) + } + } else { + let proto_ctx = ProtoCtx::new(exch_ctx, self.rx, self.tx); + self.state = RecvState::Ack; + + Ok(Some(Some(RecvAction::Interact(proto_ctx)))) + } + } + Ok(None) => { + self.state = RecvState::Ack; + Ok(None) } + Err(Error::NoSpace) => { + self.state = RecvState::EvictSession; + Ok(None) + } + Err(err) => Err(err), + }, + RecvState::EvictSession => { + self.mgr.exch_mgr.evict_session(self.tx)?; + self.state = RecvState::OpenExchange; + Ok(Some(Some(RecvAction::Send(self.tx.as_slice())))) } - Err(e) => { - error!("Error in proto_demux {:?}", e); - return Err(e); + RecvState::Ack => { + if let Some(exch_id) = self.mgr.exch_mgr.pending_ack() { + info!("Sending MRP Standalone ACK for exch {}", exch_id); + + ReliableMessage::prepare_ack(exch_id, self.tx); + + self.mgr.exch_mgr.send(exch_id, self.tx)?; + Ok(Some(Some(RecvAction::Send(self.tx.as_slice())))) + } else { + Ok(Some(None)) + } } } + } +} - let ProtoCtx { - exch_ctx, - rx: _, - tx, - } = proto_ctx; +#[derive(Copy, Clone, Eq, PartialEq)] +enum NotifyState {} - // tx_ctx now contains the response payload, send the packet - let exch_id = exch_ctx.exch.get_id(); - self.send_to_exchange(exch_id, tx).map_err(|e| { - error!("Error in sending msg {:?}", e); - e - })?; +pub enum NotifyAction<'r, 'p> { + Send(&'r [u8]), + Notify(ProtoCtx<'r, 'p>), +} - Ok(()) - } +pub struct NotifyCompletion<'r, 'a, 'p> { + // TODO + _mgr: &'r mut TransportMgr<'a>, + _rx: &'r mut Packet<'p>, + _tx: &'r mut Packet<'p>, + _state: NotifyState, +} - fn handle_queue_msgs(&mut self) -> Result<(), Error> { - if let Ok(msg) = self.rx_q.try_recv() { - match msg { - Msg::NewSession(clone_data) => { - // If a new session was created, add it - let _ = self - .exch_mgr - .add_session(&clone_data) - .map_err(|e| error!("Error adding new session {:?}", e)); - } - _ => { - error!("Queue Message Type not yet handled {:?}", msg); - } +impl<'r, 'a, 'p> NotifyCompletion<'r, 'a, 'p> { + pub fn next_action(&mut self) -> Result>, Error> { + loop { + // Polonius will remove the need for unsafe one day + let this = unsafe { (self as *mut NotifyCompletion).as_mut().unwrap() }; + + if let Some(action) = this.maybe_next_action()? { + return Ok(action); } } - Ok(()) } - pub fn start(&mut self) -> Result<(), Error> { - loop { - // Handle network operations - if self.handle_rxtx().is_err() { - error!("Error in handle_rxtx"); - continue; - } + fn maybe_next_action(&mut self) -> Result>>, Error> { + Ok(Some(None)) // TODO: Future + } +} - if self.handle_queue_msgs().is_err() { - error!("Error in handle_queue_msg"); - continue; - } +pub struct TransportMgr<'a> { + exch_mgr: exchange::ExchangeMgr, + secure_channel: SecureChannel<'a>, +} - // Handle any pending acknowledgement send - let mut acks_to_send: LinearMap = - LinearMap::new(); - self.exch_mgr.pending_acks(&mut acks_to_send); - for exch_id in acks_to_send.keys() { - info!("Sending MRP Standalone ACK for exch {}", exch_id); - let mut proto_tx = match Self::new_tx() { - Ok(p) => p, - Err(e) => { - error!("Error creating proto_tx {:?}", e); - break; - } - }; - ReliableMessage::prepare_ack(*exch_id, &mut proto_tx); - if let Err(e) = self.send_to_exchange(*exch_id, proto_tx) { - error!("Error in sending Ack {:?}", e); - } - } +impl<'a> TransportMgr<'a> { + pub fn new< + T: Borrow> + Borrow> + Borrow + Borrow, + >( + matter: &'a T, + mdns_mgr: &'a RefCell>, + ) -> Self { + Self::wrap( + SecureChannel::new(matter.borrow(), matter.borrow(), mdns_mgr, *matter.borrow()), + *matter.borrow(), + *matter.borrow(), + ) + } - // Handle exchange purging - // This need not be done in each turn of the loop, maybe once in 5 times or so? - self.exch_mgr.purge(); + pub fn wrap(secure_channel: SecureChannel<'a>, epoch: Epoch, rand: Rand) -> Self { + Self { + exch_mgr: exchange::ExchangeMgr::new(epoch, rand), + secure_channel, + } + } - info!("Exchange Mgr: {}", self.exch_mgr); + pub fn recv<'r, 'p>( + &'r mut self, + rx: &'r mut Packet<'p>, + tx: &'r mut Packet<'p>, + ) -> RecvCompletion<'r, 'a, 'p> { + RecvCompletion { + mgr: self, + rx, + tx, + state: RecvState::New, } } - fn new_tx() -> Result, Error> { - Slab::::try_new(Packet::new_tx()?).ok_or(Error::PacketPoolExhaust) + pub fn notify(&mut self, _tx: &mut Packet) -> Result { + Ok(false) } + + // async fn handle_queue_msgs(&mut self) -> Result<(), Error> { + // if let Ok(msg) = self.rx_q.try_recv() { + // match msg { + // Msg::NewSession(clone_data) => { + // // If a new session was created, add it + // let _ = self + // .exch_mgr + // .add_session(&clone_data) + // .await + // .map_err(|e| error!("Error adding new session {:?}", e)); + // } + // _ => { + // error!("Queue Message Type not yet handled {:?}", msg); + // } + // } + // } + // Ok(()) + // } } diff --git a/matter/src/transport/mod.rs b/matter/src/transport/mod.rs index b3a2545c..43acccd7 100644 --- a/matter/src/transport/mod.rs +++ b/matter/src/transport/mod.rs @@ -22,7 +22,7 @@ pub mod mrp; pub mod network; pub mod packet; pub mod plain_hdr; -pub mod proto_demux; +pub mod proto_ctx; pub mod proto_hdr; pub mod queue; pub mod session; diff --git a/matter/src/transport/mrp.rs b/matter/src/transport/mrp.rs index 22cf9ad6..2213a524 100644 --- a/matter/src/transport/mrp.rs +++ b/matter/src/transport/mrp.rs @@ -15,8 +15,8 @@ * limitations under the License. */ -use std::time::Duration; -use std::time::SystemTime; +use crate::utils::epoch::Epoch; +use core::time::Duration; use crate::{error::*, secure_channel, transport::packet::Packet}; use log::error; @@ -46,13 +46,13 @@ pub struct AckEntry { // The msg counter that we should acknowledge msg_ctr: u32, // The max time after which this entry must be ACK - ack_timeout: SystemTime, + ack_timeout: Duration, } impl AckEntry { - pub fn new(msg_ctr: u32) -> Result { + pub fn new(msg_ctr: u32, epoch: Epoch) -> Result { if let Some(ack_timeout) = - SystemTime::now().checked_add(Duration::from_millis(MRP_STANDALONE_ACK_TIMEOUT)) + epoch().checked_add(Duration::from_millis(MRP_STANDALONE_ACK_TIMEOUT)) { Ok(Self { msg_ctr, @@ -67,8 +67,8 @@ impl AckEntry { self.msg_ctr } - pub fn has_timed_out(&self) -> bool { - self.ack_timeout > SystemTime::now() + pub fn has_timed_out(&self, epoch: Epoch) -> bool { + self.ack_timeout > epoch() } } @@ -90,10 +90,10 @@ impl ReliableMessage { } // Check any pending acknowledgements / retransmissions and take action - pub fn is_ack_ready(&self) -> bool { + pub fn is_ack_ready(&self, epoch: Epoch) -> bool { // Acknowledgements if let Some(ack_entry) = self.ack { - ack_entry.has_timed_out() + ack_entry.has_timed_out(epoch) } else { false } @@ -132,7 +132,7 @@ impl ReliableMessage { * - there can be only one pending retransmission per exchange (so this is per-exchange) * - duplicate detection should happen per session (obviously), so that part is per-session */ - pub fn recv(&mut self, proto_rx: &Packet) -> Result<(), Error> { + pub fn recv(&mut self, proto_rx: &Packet, epoch: Epoch) -> Result<(), Error> { if proto_rx.proto.is_ack() { // Handle received Acks let ack_msg_ctr = proto_rx.proto.get_ack_msg_ctr().ok_or(Error::Invalid)?; @@ -153,7 +153,7 @@ impl ReliableMessage { return Err(Error::Invalid); } - self.ack = Some(AckEntry::new(proto_rx.plain.ctr)?); + self.ack = Some(AckEntry::new(proto_rx.plain.ctr, epoch)?); } Ok(()) } diff --git a/matter/src/transport/network.rs b/matter/src/transport/network.rs index 5b398ca4..91645de6 100644 --- a/matter/src/transport/network.rs +++ b/matter/src/transport/network.rs @@ -15,12 +15,8 @@ * limitations under the License. */ -use std::{ - fmt::{Debug, Display}, - net::{IpAddr, Ipv4Addr, SocketAddr}, -}; - -use crate::error::Error; +use core::fmt::{Debug, Display}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; #[derive(PartialEq, Copy, Clone)] pub enum Address { @@ -34,7 +30,7 @@ impl Default for Address { } impl Display for Address { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { Address::Udp(addr) => writeln!(f, "{}", addr), } @@ -42,14 +38,9 @@ impl Display for Address { } impl Debug for Address { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { Address::Udp(addr) => writeln!(f, "{}", addr), } } } - -pub trait NetworkInterface { - fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, Address), Error>; - fn send(&self, out_buf: &[u8], addr: Address) -> Result; -} diff --git a/matter/src/transport/packet.rs b/matter/src/transport/packet.rs index 18af1b59..e39ac1c9 100644 --- a/matter/src/transport/packet.rs +++ b/matter/src/transport/packet.rs @@ -99,46 +99,46 @@ pub struct Packet<'a> { pub proto: ProtoHdr, pub peer: Address, data: Direction<'a>, - buffer_index: usize, } impl<'a> Packet<'a> { const HDR_RESERVE: usize = plain_hdr::max_plain_hdr_len() + proto_hdr::max_proto_hdr_len(); - pub fn new_rx() -> Result { - let (buffer_index, buffer) = BufferPool::alloc().ok_or(Error::NoSpace)?; - let buf_len = buffer.len(); - Ok(Self { + pub fn new_rx(buf: &'a mut [u8]) -> Self { + Self { plain: Default::default(), proto: Default::default(), - buffer_index, peer: Address::default(), - data: Direction::Rx(ParseBuf::new(buffer, buf_len), RxState::Uninit), - }) + data: Direction::Rx(ParseBuf::new(buf), RxState::Uninit), + } } - pub fn new_tx() -> Result { - let (buffer_index, buffer) = BufferPool::alloc().ok_or(Error::NoSpace)?; - let buf_len = buffer.len(); + pub fn new_tx(buf: &'a mut [u8]) -> Self { + let mut wb = WriteBuf::new(buf); + wb.reserve(Packet::HDR_RESERVE).unwrap(); - let mut wb = WriteBuf::new(buffer, buf_len); - wb.reserve(Packet::HDR_RESERVE)?; + // Reliability on by default + let mut proto: ProtoHdr = Default::default(); + proto.set_reliable(); - let mut p = Self { + Self { plain: Default::default(), - proto: Default::default(), - buffer_index, + proto, peer: Address::default(), data: Direction::Tx(wb), - }; - // Reliability on by default - p.proto.set_reliable(); - Ok(p) + } + } + + pub fn as_slice(&self) -> &[u8] { + match &self.data { + Direction::Rx(pb, _) => pb.as_slice(), + Direction::Tx(wb) => wb.as_slice(), + } } - pub fn as_borrow_slice(&mut self) -> &mut [u8] { + pub fn as_mut_slice(&mut self) -> &mut [u8] { match &mut self.data { - Direction::Rx(pb, _) => pb.as_borrow_slice(), + Direction::Rx(pb, _) => pb.as_mut_slice(), Direction::Tx(wb) => wb.as_mut_slice(), } } @@ -229,11 +229,4 @@ impl<'a> Packet<'a> { } } -impl<'a> Drop for Packet<'a> { - fn drop(&mut self) { - BufferPool::free(self.buffer_index); - trace!("Dropping Packet......"); - } -} - box_slab!(PacketPool, Packet<'static>, MAX_PACKET_POOL_SIZE); diff --git a/matter/src/transport/plain_hdr.rs b/matter/src/transport/plain_hdr.rs index 5e54cd1f..e51ddaf0 100644 --- a/matter/src/transport/plain_hdr.rs +++ b/matter/src/transport/plain_hdr.rs @@ -21,18 +21,13 @@ use crate::utils::writebuf::WriteBuf; use bitflags::bitflags; use log::info; -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Default)] pub enum SessionType { + #[default] None, Encrypted, } -impl Default for SessionType { - fn default() -> SessionType { - SessionType::None - } -} - bitflags! { #[derive(Default)] pub struct MsgFlags: u8 { diff --git a/matter/src/transport/proto_ctx.rs b/matter/src/transport/proto_ctx.rs new file mode 100644 index 00000000..747a1e6a --- /dev/null +++ b/matter/src/transport/proto_ctx.rs @@ -0,0 +1,43 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use crate::error::Error; + +use super::exchange::ExchangeCtx; +use super::packet::Packet; + +/// This is the context in which a receive packet is being processed +pub struct ProtoCtx<'a, 'b> { + /// This is the exchange context, that includes the exchange and the session + pub exch_ctx: ExchangeCtx<'a>, + /// This is the received buffer for this transaction + pub rx: &'a Packet<'b>, + /// This is the transmit buffer for this transaction + pub tx: &'a mut Packet<'b>, +} + +impl<'a, 'b> ProtoCtx<'a, 'b> { + pub fn new(exch_ctx: ExchangeCtx<'a>, rx: &'a Packet<'b>, tx: &'a mut Packet<'b>) -> Self { + Self { exch_ctx, rx, tx } + } + + pub fn send(&mut self) -> Result<&[u8], Error> { + self.exch_ctx.exch.send(self.tx, &mut self.exch_ctx.sess)?; + + Ok(self.tx.as_mut_slice()) + } +} diff --git a/matter/src/transport/proto_demux.rs b/matter/src/transport/proto_demux.rs deleted file mode 100644 index 263ffc92..00000000 --- a/matter/src/transport/proto_demux.rs +++ /dev/null @@ -1,95 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use boxslab::BoxSlab; - -use crate::error::*; - -use super::exchange::ExchangeCtx; -use super::packet::PacketPool; - -const MAX_PROTOCOLS: usize = 4; - -#[derive(PartialEq, Debug)] -pub enum ResponseRequired { - Yes, - No, -} -pub struct ProtoDemux { - proto_id_handlers: [Option>; MAX_PROTOCOLS], -} - -/// This is the context in which a receive packet is being processed -pub struct ProtoCtx<'a> { - /// This is the exchange context, that includes the exchange and the session - pub exch_ctx: ExchangeCtx<'a>, - /// This is the received buffer for this transaction - pub rx: BoxSlab, - /// This is the transmit buffer for this transaction - pub tx: BoxSlab, -} - -impl<'a> ProtoCtx<'a> { - pub fn new( - exch_ctx: ExchangeCtx<'a>, - rx: BoxSlab, - tx: BoxSlab, - ) -> Self { - Self { exch_ctx, rx, tx } - } -} - -pub trait HandleProto { - fn handle_proto_id(&mut self, proto_ctx: &mut ProtoCtx) -> Result; - - fn get_proto_id(&self) -> usize; - - fn handle_session_event(&self) -> Result<(), Error> { - Ok(()) - } -} - -impl Default for ProtoDemux { - fn default() -> Self { - Self::new() - } -} - -impl ProtoDemux { - pub fn new() -> ProtoDemux { - ProtoDemux { - proto_id_handlers: [None, None, None, None], - } - } - - pub fn register(&mut self, proto_id_handle: Box) -> Result<(), Error> { - let proto_id = proto_id_handle.get_proto_id(); - self.proto_id_handlers[proto_id] = Some(proto_id_handle); - Ok(()) - } - - pub fn handle(&mut self, proto_ctx: &mut ProtoCtx) -> Result { - let proto_id = proto_ctx.rx.get_proto_id() as usize; - if proto_id >= MAX_PROTOCOLS { - return Err(Error::Invalid); - } - return self.proto_id_handlers[proto_id] - .as_mut() - .ok_or(Error::NoHandler)? - .handle_proto_id(proto_ctx); - } -} diff --git a/matter/src/transport/proto_hdr.rs b/matter/src/transport/proto_hdr.rs index 3eb85701..96928ac2 100644 --- a/matter/src/transport/proto_hdr.rs +++ b/matter/src/transport/proto_hdr.rs @@ -16,7 +16,7 @@ */ use bitflags::bitflags; -use std::fmt; +use core::fmt; use crate::transport::plain_hdr; use crate::utils::parsebuf::ParseBuf; @@ -117,7 +117,7 @@ impl ProtoHdr { if self.is_ack() { self.ack_msg_ctr = Some(parsebuf.le_u32()?); } - trace!("[rx payload]: {:x?}", parsebuf.as_borrow_slice()); + trace!("[rx payload]: {:x?}", parsebuf.as_mut_slice()); Ok(()) } @@ -139,21 +139,21 @@ impl ProtoHdr { impl fmt::Display for ProtoHdr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut flag_str: String = "".to_owned(); + let mut flag_str = heapless::String::<16>::new(); if self.is_vendor() { - flag_str.push_str("V|"); + flag_str.push_str("V|").unwrap(); } if self.is_security_ext() { - flag_str.push_str("SX|"); + flag_str.push_str("SX|").unwrap(); } if self.is_reliable() { - flag_str.push_str("R|"); + flag_str.push_str("R|").unwrap(); } if self.is_ack() { - flag_str.push_str("A|"); + flag_str.push_str("A|").unwrap(); } if self.is_initiator() { - flag_str.push_str("I|"); + flag_str.push_str("I|").unwrap(); } write!( f, @@ -165,7 +165,7 @@ impl fmt::Display for ProtoHdr { fn get_iv(recvd_ctr: u32, peer_nodeid: u64, iv: &mut [u8]) -> Result<(), Error> { // The IV is the source address (64-bit) followed by the message counter (32-bit) - let mut write_buf = WriteBuf::new(iv, iv.len()); + let mut write_buf = WriteBuf::new(iv); // For some reason, this is 0 in the 'bypass' mode write_buf.le_u8(0)?; write_buf.le_u32(recvd_ctr)?; @@ -224,7 +224,7 @@ fn decrypt_in_place( let mut iv = [0_u8; crypto::AEAD_NONCE_LEN_BYTES]; get_iv(recvd_ctr, peer_nodeid, &mut iv)?; - let cipher_text = parsebuf.as_borrow_slice(); + let cipher_text = parsebuf.as_mut_slice(); //println!("AAD: {:x?}", aad); //println!("Cipher Text: {:x?}", cipher_text); //println!("IV: {:x?}", iv); @@ -266,8 +266,7 @@ mod tests { 0x1f, 0xb0, 0x5e, 0xbe, 0xb5, 0x10, 0xad, 0xc6, 0x78, 0x94, 0x50, 0xe5, 0xd2, 0xe0, 0x80, 0xef, 0xa8, 0x3a, 0xf0, 0xa6, 0xaf, 0x1b, 0x2, 0x35, 0xa7, 0xd1, 0xc6, 0x32, ]; - let input_buf_len = input_buf.len(); - let mut parsebuf = ParseBuf::new(&mut input_buf, input_buf_len); + let mut parsebuf = ParseBuf::new(&mut input_buf); let key = [ 0x66, 0x63, 0x31, 0x97, 0x43, 0x9c, 0x17, 0xb9, 0x7e, 0x10, 0xee, 0x47, 0xc8, 0x8, 0x80, 0x4a, @@ -279,7 +278,7 @@ mod tests { decrypt_in_place(recvd_ctr, 0, &mut parsebuf, &key).unwrap(); assert_eq!( - parsebuf.as_slice(), + parsebuf.into_slice(), [ 0x5, 0x8, 0x70, 0x0, 0x1, 0x0, 0x15, 0x28, 0x0, 0x28, 0x1, 0x36, 0x2, 0x15, 0x37, 0x0, 0x24, 0x0, 0x0, 0x24, 0x1, 0x30, 0x24, 0x2, 0x2, 0x18, 0x35, 0x1, 0x24, 0x0, @@ -295,8 +294,7 @@ mod tests { let send_ctr = 41; let mut main_buf: [u8; 52] = [0; 52]; - let main_buf_len = main_buf.len(); - let mut writebuf = WriteBuf::new(&mut main_buf, main_buf_len); + let mut writebuf = WriteBuf::new(&mut main_buf); let plain_hdr: [u8; 8] = [0x0, 0x11, 0x0, 0x0, 0x29, 0x0, 0x0, 0x0]; @@ -313,7 +311,7 @@ mod tests { encrypt_in_place(send_ctr, 0, &plain_hdr, &mut writebuf, &key).unwrap(); assert_eq!( - writebuf.as_slice(), + writebuf.into_slice(), [ 189, 83, 250, 121, 38, 87, 97, 17, 153, 78, 243, 20, 36, 11, 131, 142, 136, 165, 227, 107, 204, 129, 193, 153, 42, 131, 138, 254, 22, 190, 76, 244, 116, 45, 156, diff --git a/matter/src/transport/session.rs b/matter/src/transport/session.rs index 8faf8135..d4c49852 100644 --- a/matter/src/transport/session.rs +++ b/matter/src/transport/session.rs @@ -15,11 +15,14 @@ * limitations under the License. */ +use crate::data_model::sdm::noc::NocData; +use crate::utils::epoch::Epoch; +use crate::utils::rand::Rand; use core::fmt; -use std::{ +use core::time::Duration; +use core::{ any::Any, ops::{Deref, DerefMut}, - time::SystemTime, }; use crate::{ @@ -27,16 +30,10 @@ use crate::{ transport::{plain_hdr, proto_hdr}, utils::writebuf::WriteBuf, }; -use boxslab::{BoxSlab, Slab}; -use colored::*; use log::{info, trace}; -use rand::Rng; -use super::{ - dedup::RxCtrState, - network::{Address, NetworkInterface}, - packet::{Packet, PacketPool}, -}; +use super::dedup::RxCtrState; +use super::{network::Address, packet::Packet}; pub const MAX_CAT_IDS_PER_NOC: usize = 3; pub type NocCatIds = [u32; MAX_CAT_IDS_PER_NOC]; @@ -58,21 +55,15 @@ impl CaseDetails { } } -#[derive(Debug, PartialEq, Copy, Clone)] +#[derive(Debug, PartialEq, Copy, Clone, Default)] pub enum SessionMode { // The Case session will capture the local fabric index Case(CaseDetails), Pase, + #[default] PlainText, } -impl Default for SessionMode { - fn default() -> Self { - SessionMode::PlainText - } -} - -#[derive(Debug)] pub struct Session { peer_addr: Address, local_nodeid: u64, @@ -87,8 +78,8 @@ pub struct Session { msg_ctr: u32, rx_ctr_state: RxCtrState, mode: SessionMode, - data: Option>, - last_use: SystemTime, + data: Option, + last_use: Duration, } #[derive(Debug)] @@ -103,6 +94,7 @@ pub struct CloneData { peer_addr: Address, mode: SessionMode, } + impl CloneData { pub fn new( local_nodeid: u64, @@ -129,8 +121,8 @@ impl CloneData { const MATTER_MSG_CTR_RANGE: u32 = 0x0fffffff; impl Session { - pub fn new(peer_addr: Address, peer_nodeid: Option) -> Session { - Session { + pub fn new(peer_addr: Address, peer_nodeid: Option, epoch: Epoch, rand: Rand) -> Self { + Self { peer_addr, local_nodeid: 0, peer_nodeid, @@ -139,16 +131,16 @@ impl Session { att_challenge: [0; MATTER_AES128_KEY_SIZE], peer_sess_id: 0, local_sess_id: 0, - msg_ctr: rand::thread_rng().gen_range(0..MATTER_MSG_CTR_RANGE), + msg_ctr: Self::rand_msg_ctr(rand), rx_ctr_state: RxCtrState::new(0), mode: SessionMode::PlainText, data: None, - last_use: SystemTime::now(), + last_use: epoch(), } } // A new encrypted session always clones from a previous 'new' session - pub fn clone(clone_from: &CloneData) -> Session { + pub fn clone(clone_from: &CloneData, epoch: Epoch, rand: Rand) -> Session { Session { peer_addr: clone_from.peer_addr, local_nodeid: clone_from.local_nodeid, @@ -158,28 +150,28 @@ impl Session { att_challenge: clone_from.att_challenge, local_sess_id: clone_from.local_sess_id, peer_sess_id: clone_from.peer_sess_id, - msg_ctr: rand::thread_rng().gen_range(0..MATTER_MSG_CTR_RANGE), + msg_ctr: Self::rand_msg_ctr(rand), rx_ctr_state: RxCtrState::new(0), mode: clone_from.mode, data: None, - last_use: SystemTime::now(), + last_use: epoch(), } } - pub fn set_data(&mut self, data: Box) { + pub fn set_noc_data(&mut self, data: NocData) { self.data = Some(data); } - pub fn clear_data(&mut self) { + pub fn clear_noc_data(&mut self) { self.data = None; } - pub fn get_data(&mut self) -> Option<&mut T> { - self.data.as_mut()?.downcast_mut::() + pub fn get_noc_data(&mut self) -> Option<&mut NocData> { + self.data.as_mut() } - pub fn take_data(&mut self) -> Option> { - self.data.take()?.downcast::().ok() + pub fn take_noc_data(&mut self) -> Option { + self.data.take() } pub fn get_local_sess_id(&self) -> u16 { @@ -252,59 +244,65 @@ impl Session { &self.att_challenge } - pub fn recv(&mut self, proto_rx: &mut Packet) -> Result<(), Error> { - self.last_use = SystemTime::now(); - proto_rx.proto_decode(self.peer_nodeid.unwrap_or_default(), self.get_dec_key()) + pub fn recv(&mut self, epoch: Epoch, rx: &mut Packet) -> Result<(), Error> { + self.last_use = epoch(); + rx.proto_decode(self.peer_nodeid.unwrap_or_default(), self.get_dec_key()) } - pub fn pre_send(&mut self, proto_tx: &mut Packet) -> Result<(), Error> { - proto_tx.plain.sess_id = self.get_peer_sess_id(); - proto_tx.plain.ctr = self.get_msg_ctr(); + pub fn pre_send(&mut self, tx: &mut Packet) -> Result<(), Error> { + tx.plain.sess_id = self.get_peer_sess_id(); + tx.plain.ctr = self.get_msg_ctr(); if self.is_encrypted() { - proto_tx.plain.sess_type = plain_hdr::SessionType::Encrypted; + tx.plain.sess_type = plain_hdr::SessionType::Encrypted; } Ok(()) } // TODO: Most of this can now be moved into the 'Packet' module - fn do_send(&mut self, proto_tx: &mut Packet) -> Result<(), Error> { - self.last_use = SystemTime::now(); - proto_tx.peer = self.peer_addr; + fn do_send(&mut self, epoch: Epoch, tx: &mut Packet) -> Result<(), Error> { + self.last_use = epoch(); + tx.peer = self.peer_addr; // Generate encrypted header - let mut tmp_buf: [u8; proto_hdr::max_proto_hdr_len()] = [0; proto_hdr::max_proto_hdr_len()]; - let mut write_buf = WriteBuf::new(&mut tmp_buf[..], proto_hdr::max_proto_hdr_len()); - proto_tx.proto.encode(&mut write_buf)?; - proto_tx.get_writebuf()?.prepend(write_buf.as_slice())?; + let mut tmp_buf = [0_u8; proto_hdr::max_proto_hdr_len()]; + let mut write_buf = WriteBuf::new(&mut tmp_buf); + tx.proto.encode(&mut write_buf)?; + tx.get_writebuf()?.prepend(write_buf.into_slice())?; // Generate plain-text header if self.mode == SessionMode::PlainText { if let Some(d) = self.peer_nodeid { - proto_tx.plain.set_dest_u64(d); + tx.plain.set_dest_u64(d); } } - let mut tmp_buf: [u8; plain_hdr::max_plain_hdr_len()] = [0; plain_hdr::max_plain_hdr_len()]; - let mut write_buf = WriteBuf::new(&mut tmp_buf[..], plain_hdr::max_plain_hdr_len()); - proto_tx.plain.encode(&mut write_buf)?; - let plain_hdr_bytes = write_buf.as_slice(); + let mut tmp_buf = [0_u8; plain_hdr::max_plain_hdr_len()]; + let mut write_buf = WriteBuf::new(&mut tmp_buf); + tx.plain.encode(&mut write_buf)?; + let plain_hdr_bytes = write_buf.into_slice(); - trace!("unencrypted packet: {:x?}", proto_tx.as_borrow_slice()); - let ctr = proto_tx.plain.ctr; + trace!("unencrypted packet: {:x?}", tx.as_mut_slice()); + let ctr = tx.plain.ctr; let enc_key = self.get_enc_key(); if let Some(e) = enc_key { proto_hdr::encrypt_in_place( ctr, self.local_nodeid, plain_hdr_bytes, - proto_tx.get_writebuf()?, + tx.get_writebuf()?, e, )?; } - proto_tx.get_writebuf()?.prepend(plain_hdr_bytes)?; - trace!("Full encrypted packet: {:x?}", proto_tx.as_borrow_slice()); + tx.get_writebuf()?.prepend(plain_hdr_bytes)?; + trace!("Full encrypted packet: {:x?}", tx.as_mut_slice()); Ok(()) } + + fn rand_msg_ctr(rand: Rand) -> u32 { + let mut buf = [0; 4]; + rand(&mut buf); + u32::from_be_bytes(buf) & MATTER_MSG_CTR_RANGE + } } impl fmt::Display for Session { @@ -324,36 +322,23 @@ impl fmt::Display for Session { } pub const MAX_SESSIONS: usize = 16; + pub struct SessionMgr { next_sess_id: u16, sessions: [Option; MAX_SESSIONS], - network: Option>, -} - -impl Default for SessionMgr { - fn default() -> Self { - Self::new() - } + epoch: Epoch, + rand: Rand, } impl SessionMgr { - pub fn new() -> SessionMgr { - SessionMgr { - sessions: Default::default(), - next_sess_id: 1, - network: None, - } - } + pub fn new(epoch: Epoch, rand: Rand) -> Self { + const INIT: Option = None; - pub fn add_network_interface( - &mut self, - interface: Box, - ) -> Result<(), Error> { - if self.network.is_none() { - self.network = Some(interface); - Ok(()) - } else { - Err(Error::Invalid) + Self { + sessions: [INIT; MAX_SESSIONS], + next_sess_id: 1, + epoch, + rand, } } @@ -380,13 +365,21 @@ impl SessionMgr { next_sess_id } + pub fn get_session_for_eviction(&self) -> Option { + if self.get_empty_slot().is_none() { + Some(self.get_lru()) + } else { + None + } + } + fn get_empty_slot(&self) -> Option { self.sessions.iter().position(|x| x.is_none()) } - pub fn get_lru(&mut self) -> usize { + fn get_lru(&self) -> usize { let mut lru_index = 0; - let mut lru_ts = SystemTime::now(); + let mut lru_ts = (self.epoch)(); for i in 0..MAX_SESSIONS { if let Some(s) = &self.sessions[i] { if s.last_use < lru_ts { @@ -399,7 +392,7 @@ impl SessionMgr { } pub fn add(&mut self, peer_addr: Address, peer_nodeid: Option) -> Result { - let session = Session::new(peer_addr, peer_nodeid); + let session = Session::new(peer_addr, peer_nodeid, self.epoch, self.rand); self.add_session(session) } @@ -422,7 +415,7 @@ impl SessionMgr { } pub fn clone_session(&mut self, clone_data: &CloneData) -> Result { - let session = Session::clone(clone_data); + let session = Session::clone(clone_data, self.epoch, self.rand); self.add_session(session) } @@ -478,68 +471,50 @@ impl SessionMgr { // We will try to get a session for this Packet. If no session exists, we will try to add one // If the session list is full we will return a None - pub fn post_recv(&mut self, rx: &Packet) -> Result, Error> { - let sess_index = match self.get_or_add( + pub fn post_recv(&mut self, rx: &Packet) -> Result { + let sess_index = self.get_or_add( rx.plain.sess_id, rx.peer, rx.plain.get_src_u64(), rx.plain.is_encrypted(), - ) { - Ok(s) => { - let session = self.sessions[s].as_mut().unwrap(); - let is_encrypted = session.is_encrypted(); - let duplicate = session.rx_ctr_state.recv(rx.plain.ctr, is_encrypted); - if duplicate { - info!("Dropping duplicate packet"); - return Err(Error::Duplicate); - } else { - Some(s) - } - } - Err(Error::NoSpace) => None, - Err(e) => { - return Err(e); - } - }; - Ok(sess_index) + )?; + + let session = self.sessions[sess_index].as_mut().unwrap(); + let is_encrypted = session.is_encrypted(); + let duplicate = session.rx_ctr_state.recv(rx.plain.ctr, is_encrypted); + if duplicate { + info!("Dropping duplicate packet"); + Err(Error::Duplicate) + } else { + Ok(sess_index) + } } - pub fn recv(&mut self) -> Result<(BoxSlab, Option), Error> { - let mut rx = - Slab::::try_new(Packet::new_rx()?).ok_or(Error::PacketPoolExhaust)?; - - let network = self.network.as_ref().ok_or(Error::NoNetworkInterface)?; + pub fn decode(&mut self, rx: &mut Packet) -> Result<(), Error> { + // let network = self.network.as_ref().ok_or(Error::NoNetworkInterface)?; - let (len, src) = network.recv(rx.as_borrow_slice())?; - rx.get_parsebuf()?.set_len(len); - rx.peer = src; + // let (len, src) = network.recv(rx.as_borrow_slice()).await?; + // rx.get_parsebuf()?.set_len(len); + // rx.peer = src; - info!("{} from src: {}", "Received".blue(), src); - trace!("payload: {:x?}", rx.as_borrow_slice()); + // info!("{} from src: {}", "Received".blue(), src); + // trace!("payload: {:x?}", rx.as_borrow_slice()); // Read unencrypted packet header - rx.plain_hdr_decode()?; - - // Get session - let sess_handle = self.post_recv(&rx)?; - - Ok((rx, sess_handle)) + rx.plain_hdr_decode() } - pub fn send( - &mut self, - sess_idx: usize, - mut proto_tx: BoxSlab, - ) -> Result<(), Error> { + pub fn send(&mut self, sess_idx: usize, tx: &mut Packet) -> Result<(), Error> { self.sessions[sess_idx] .as_mut() .ok_or(Error::NoSession)? - .do_send(&mut proto_tx)?; + .do_send(self.epoch, tx)?; + + // let network = self.network.as_ref().ok_or(Error::NoNetworkInterface)?; + // let peer = proto_tx.peer; + // network.send(proto_tx.as_borrow_slice(), peer).await?; + // info!("Message Sent to {}", peer); - let network = self.network.as_ref().ok_or(Error::NoNetworkInterface)?; - let peer = proto_tx.peer; - network.send(proto_tx.as_borrow_slice(), peer)?; - println!("Message Sent to {}", peer); Ok(()) } @@ -568,40 +543,52 @@ pub struct SessionHandle<'a> { } impl<'a> SessionHandle<'a> { + pub fn session(&self) -> &Session { + self.sess_mgr.sessions[self.sess_idx].as_ref().unwrap() + } + + pub fn session_mut(&mut self) -> &mut Session { + self.sess_mgr.sessions[self.sess_idx].as_mut().unwrap() + } + pub fn reserve_new_sess_id(&mut self) -> u16 { self.sess_mgr.get_next_sess_id() } - pub fn send(&mut self, proto_tx: BoxSlab) -> Result<(), Error> { - self.sess_mgr.send(self.sess_idx, proto_tx) + pub fn send(&mut self, tx: &mut Packet) -> Result<(), Error> { + self.sess_mgr.send(self.sess_idx, tx) } } impl<'a> Deref for SessionHandle<'a> { type Target = Session; + fn deref(&self) -> &Self::Target { // There is no other option but to panic if this is None - self.sess_mgr.sessions[self.sess_idx].as_ref().unwrap() + self.session() } } impl<'a> DerefMut for SessionHandle<'a> { fn deref_mut(&mut self) -> &mut Self::Target { // There is no other option but to panic if this is None - self.sess_mgr.sessions[self.sess_idx].as_mut().unwrap() + self.session_mut() } } #[cfg(test)] mod tests { - use crate::transport::network::Address; + use crate::{ + transport::network::Address, + utils::{epoch::dummy_epoch, rand::dummy_rand}, + }; use super::SessionMgr; #[test] fn test_next_sess_id_doesnt_reuse() { - let mut sm = SessionMgr::new(); + let mut sm = SessionMgr::new(dummy_epoch, dummy_rand); let sess_idx = sm.add(Address::default(), None).unwrap(); let mut sess = sm.get_session_handle(sess_idx); sess.set_local_sess_id(1); @@ -615,7 +602,7 @@ mod tests { #[test] fn test_next_sess_id_overflows() { - let mut sm = SessionMgr::new(); + let mut sm = SessionMgr::new(dummy_epoch, dummy_rand); let sess_idx = sm.add(Address::default(), None).unwrap(); let mut sess = sm.get_session_handle(sess_idx); sess.set_local_sess_id(1); diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index 8abe9af1..6f7a2651 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -16,9 +16,10 @@ */ use crate::error::*; +use log::info; use smol::net::{Ipv6Addr, UdpSocket}; -use super::network::{Address, NetworkInterface}; +use super::network::Address; // We could get rid of the smol here, but keeping it around in case we have to process // any other events in this thread's context @@ -33,25 +34,26 @@ pub const MAX_RX_BUF_SIZE: usize = 1583; pub const MATTER_PORT: u16 = 5540; impl UdpListener { - pub fn new() -> Result { + pub async fn new() -> Result { Ok(UdpListener { - socket: smol::block_on(UdpSocket::bind((Ipv6Addr::UNSPECIFIED, MATTER_PORT)))?, + socket: UdpSocket::bind((Ipv6Addr::UNSPECIFIED, MATTER_PORT)).await?, }) } -} -impl NetworkInterface for UdpListener { - fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, Address), Error> { - let (size, addr) = smol::block_on(self.socket.recv_from(in_buf)).map_err(|e| { - println!("Error on the network: {:?}", e); + pub async fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, Address), Error> { + let (size, addr) = self.socket.recv_from(in_buf).await.map_err(|e| { + info!("Error on the network: {:?}", e); Error::Network })?; Ok((size, Address::Udp(addr))) } - fn send(&self, out_buf: &[u8], addr: Address) -> Result { + pub async fn send(&self, out_buf: &[u8], addr: Address) -> Result { match addr { - Address::Udp(addr) => Ok(smol::block_on(self.socket.send_to(out_buf, addr))?), + Address::Udp(addr) => self.socket.send_to(out_buf, addr).await.map_err(|e| { + info!("Error on the network: {:?}", e); + Error::Network + }), } } } diff --git a/matter/src/utils/epoch.rs b/matter/src/utils/epoch.rs new file mode 100644 index 00000000..999cdf38 --- /dev/null +++ b/matter/src/utils/epoch.rs @@ -0,0 +1,14 @@ +use core::time::Duration; + +pub type Epoch = fn() -> Duration; + +pub fn dummy_epoch() -> Duration { + Duration::from_secs(0) +} + +#[cfg(feature = "std")] +pub fn sys_epoch() -> Duration { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() +} diff --git a/matter/src/utils/mod.rs b/matter/src/utils/mod.rs index 9fc44a87..1e69b847 100644 --- a/matter/src/utils/mod.rs +++ b/matter/src/utils/mod.rs @@ -15,5 +15,7 @@ * limitations under the License. */ +pub mod epoch; pub mod parsebuf; +pub mod rand; pub mod writebuf; diff --git a/matter/src/utils/parsebuf.rs b/matter/src/utils/parsebuf.rs index b5342c03..d6a8b9a9 100644 --- a/matter/src/utils/parsebuf.rs +++ b/matter/src/utils/parsebuf.rs @@ -25,11 +25,13 @@ pub struct ParseBuf<'a> { } impl<'a> ParseBuf<'a> { - pub fn new(buf: &'a mut [u8], len: usize) -> ParseBuf<'a> { - ParseBuf { - buf: &mut buf[..len], + pub fn new(buf: &'a mut [u8]) -> Self { + let left = buf.len(); + + Self { + buf, read_off: 0, - left: len, + left, } } @@ -38,12 +40,17 @@ impl<'a> ParseBuf<'a> { } // Return the data that is valid as a slice, consume self - pub fn as_slice(self) -> &'a mut [u8] { + pub fn into_slice(self) -> &'a mut [u8] { &mut self.buf[self.read_off..(self.read_off + self.left)] } // Return the data that is valid as a slice - pub fn as_borrow_slice(&mut self) -> &mut [u8] { + pub fn as_slice(&self) -> &[u8] { + &self.buf[self.read_off..(self.read_off + self.left)] + } + + // Return the data that is valid as a slice + pub fn as_mut_slice(&mut self) -> &mut [u8] { &mut self.buf[self.read_off..(self.read_off + self.left)] } @@ -101,19 +108,19 @@ mod tests { #[test] fn test_parse_with_success() { - let mut test_slice: [u8; 11] = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; - let mut buf = ParseBuf::new(&mut test_slice, 11); + let mut test_slice = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; + let mut buf = ParseBuf::new(&mut test_slice); assert_eq!(buf.le_u8().unwrap(), 0x01); assert_eq!(buf.le_u16().unwrap(), 65); assert_eq!(buf.le_u32().unwrap(), 0xcafebabe); - assert_eq!(buf.as_slice(), [0xa, 0xb, 0xc, 0xd]); + assert_eq!(buf.into_slice(), [0xa, 0xb, 0xc, 0xd]); } #[test] fn test_parse_with_overrun() { - let mut test_slice: [u8; 2] = [0x01, 65]; - let mut buf = ParseBuf::new(&mut test_slice, 2); + let mut test_slice = [0x01, 65]; + let mut buf = ParseBuf::new(&mut test_slice); assert_eq!(buf.le_u8().unwrap(), 0x01); @@ -131,29 +138,29 @@ mod tests { if buf.le_u8().is_ok() { panic!("This should have returned error") } - assert_eq!(buf.as_slice(), []); + assert_eq!(buf.into_slice(), []); } #[test] fn test_tail_with_success() { - let mut test_slice: [u8; 11] = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; - let mut buf = ParseBuf::new(&mut test_slice, 11); + let mut test_slice = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; + let mut buf = ParseBuf::new(&mut test_slice); assert_eq!(buf.le_u8().unwrap(), 0x01); assert_eq!(buf.le_u16().unwrap(), 65); assert_eq!(buf.le_u32().unwrap(), 0xcafebabe); assert_eq!(buf.tail(2).unwrap(), [0xc, 0xd]); - assert_eq!(buf.as_borrow_slice(), [0xa, 0xb]); + assert_eq!(buf.as_mut_slice(), [0xa, 0xb]); assert_eq!(buf.tail(2).unwrap(), [0xa, 0xb]); - assert_eq!(buf.as_slice(), []); + assert_eq!(buf.into_slice(), []); } #[test] fn test_tail_with_overrun() { - let mut test_slice: [u8; 11] = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; - let mut buf = ParseBuf::new(&mut test_slice, 11); + let mut test_slice = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; + let mut buf = ParseBuf::new(&mut test_slice); assert_eq!(buf.le_u8().unwrap(), 0x01); assert_eq!(buf.le_u16().unwrap(), 65); @@ -166,8 +173,8 @@ mod tests { #[test] fn test_parsed_as_slice() { - let mut test_slice: [u8; 11] = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; - let mut buf = ParseBuf::new(&mut test_slice, 11); + let mut test_slice = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; + let mut buf = ParseBuf::new(&mut test_slice); assert_eq!(buf.parsed_as_slice(), []); assert_eq!(buf.le_u8().unwrap(), 0x1); diff --git a/matter/src/utils/rand.rs b/matter/src/utils/rand.rs new file mode 100644 index 00000000..3cd698ca --- /dev/null +++ b/matter/src/utils/rand.rs @@ -0,0 +1,3 @@ +pub type Rand = fn(&mut [u8]); + +pub fn dummy_rand(_buf: &mut [u8]) {} diff --git a/matter/src/utils/writebuf.rs b/matter/src/utils/writebuf.rs index cf28888b..fae44818 100644 --- a/matter/src/utils/writebuf.rs +++ b/matter/src/utils/writebuf.rs @@ -53,9 +53,9 @@ pub struct WriteBuf<'a> { } impl<'a> WriteBuf<'a> { - pub fn new(buf: &'a mut [u8], len: usize) -> WriteBuf<'a> { - WriteBuf { - buf: &mut buf[..len], + pub fn new(buf: &'a mut [u8]) -> Self { + Self { + buf, start: 0, end: 0, } @@ -73,11 +73,11 @@ impl<'a> WriteBuf<'a> { self.end += new_offset } - pub fn as_borrow_slice(&self) -> &[u8] { + pub fn into_slice(self) -> &'a [u8] { &self.buf[self.start..self.end] } - pub fn as_slice(self) -> &'a [u8] { + pub fn as_slice(&self) -> &[u8] { &self.buf[self.start..self.end] } @@ -201,9 +201,8 @@ mod tests { #[test] fn test_append_le_with_success() { - let mut test_slice: [u8; 22] = [0; 22]; - let test_slice_len = test_slice.len(); - let mut buf = WriteBuf::new(&mut test_slice, test_slice_len); + let mut test_slice = [0; 22]; + let mut buf = WriteBuf::new(&mut test_slice); buf.reserve(5).unwrap(); buf.le_u8(1).unwrap(); @@ -222,8 +221,8 @@ mod tests { #[test] fn test_len_param() { - let mut test_slice: [u8; 20] = [0; 20]; - let mut buf = WriteBuf::new(&mut test_slice, 5); + let mut test_slice = [0; 20]; + let mut buf = WriteBuf::new(&mut test_slice[..5]); buf.reserve(5).unwrap(); let _ = buf.le_u8(1); @@ -236,8 +235,8 @@ mod tests { #[test] fn test_overrun() { - let mut test_slice: [u8; 20] = [0; 20]; - let mut buf = WriteBuf::new(&mut test_slice, 20); + let mut test_slice = [0; 20]; + let mut buf = WriteBuf::new(&mut test_slice); buf.reserve(4).unwrap(); buf.le_u64(0xcafebabecafebabe).unwrap(); buf.le_u64(0xcafebabecafebabe).unwrap(); @@ -262,8 +261,8 @@ mod tests { #[test] fn test_as_slice() { - let mut test_slice: [u8; 20] = [0; 20]; - let mut buf = WriteBuf::new(&mut test_slice, 20); + let mut test_slice = [0; 20]; + let mut buf = WriteBuf::new(&mut test_slice); buf.reserve(5).unwrap(); buf.le_u8(1).unwrap(); @@ -275,7 +274,7 @@ mod tests { buf.prepend(&new_slice).unwrap(); assert_eq!( - buf.as_slice(), + buf.into_slice(), [ 0xa, 0xb, 0xc, 1, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xbe, 0xba, 0xfe, 0xca, 0xbe, 0xba, 0xfe, 0xca @@ -285,8 +284,8 @@ mod tests { #[test] fn test_copy_as_slice() { - let mut test_slice: [u8; 20] = [0; 20]; - let mut buf = WriteBuf::new(&mut test_slice, 20); + let mut test_slice = [0; 20]; + let mut buf = WriteBuf::new(&mut test_slice); buf.reserve(5).unwrap(); buf.le_u16(65).unwrap(); @@ -301,8 +300,8 @@ mod tests { #[test] fn test_copy_as_slice_overrun() { - let mut test_slice: [u8; 20] = [0; 20]; - let mut buf = WriteBuf::new(&mut test_slice, 7); + let mut test_slice = [0; 20]; + let mut buf = WriteBuf::new(&mut test_slice[..7]); buf.reserve(5).unwrap(); buf.le_u16(65).unwrap(); @@ -314,8 +313,8 @@ mod tests { #[test] fn test_prepend() { - let mut test_slice: [u8; 20] = [0; 20]; - let mut buf = WriteBuf::new(&mut test_slice, 20); + let mut test_slice = [0; 20]; + let mut buf = WriteBuf::new(&mut test_slice); buf.reserve(5).unwrap(); buf.le_u16(65).unwrap(); @@ -329,8 +328,8 @@ mod tests { #[test] fn test_prepend_overrun() { - let mut test_slice: [u8; 20] = [0; 20]; - let mut buf = WriteBuf::new(&mut test_slice, 20); + let mut test_slice = [0; 20]; + let mut buf = WriteBuf::new(&mut test_slice); buf.reserve(5).unwrap(); buf.le_u16(65).unwrap(); @@ -342,8 +341,8 @@ mod tests { #[test] fn test_rewind_tail() { - let mut test_slice: [u8; 20] = [0; 20]; - let mut buf = WriteBuf::new(&mut test_slice, 20); + let mut test_slice = [0; 20]; + let mut buf = WriteBuf::new(&mut test_slice); buf.reserve(5).unwrap(); buf.le_u16(65).unwrap(); @@ -352,13 +351,10 @@ mod tests { let new_slice: [u8; 5] = [0xaa, 0xbb, 0xcc, 0xdd, 0xee]; buf.copy_from_slice(&new_slice).unwrap(); - assert_eq!( - buf.as_borrow_slice(), - [65, 0, 0xaa, 0xbb, 0xcc, 0xdd, 0xee,] - ); + assert_eq!(buf.as_slice(), [65, 0, 0xaa, 0xbb, 0xcc, 0xdd, 0xee,]); buf.rewind_tail_to(anchor); buf.le_u16(66).unwrap(); - assert_eq!(buf.as_borrow_slice(), [65, 0, 66, 0,]); + assert_eq!(buf.as_slice(), [65, 0, 66, 0,]); } } diff --git a/matter/tests/common/attributes.rs b/matter/tests/common/attributes.rs index 1879adf2..2ff95eb6 100644 --- a/matter/tests/common/attributes.rs +++ b/matter/tests/common/attributes.rs @@ -115,8 +115,7 @@ impl TLVHolder { buf: [0; 100], used_len: 0, }; - let buf_len = s.buf.len(); - let mut wb = WriteBuf::new(&mut s.buf, buf_len); + let mut wb = WriteBuf::new(&mut s.buf); let mut tw = TLVWriter::new(&mut wb); let _ = tw.start_array(TagType::Context(ctx_tag)); for e in data { diff --git a/matter/tests/common/commands.rs b/matter/tests/common/commands.rs index 919565a1..419b6ac4 100644 --- a/matter/tests/common/commands.rs +++ b/matter/tests/common/commands.rs @@ -76,7 +76,7 @@ macro_rules! echo_req { CmdPath::new( Some($endpoint), Some(echo_cluster::ID), - Some(echo_cluster::Commands::EchoReq as u16), + Some(echo_cluster::Commands::EchoReq as u32), ), EncodeValue::Value(&($data as u32)), ) @@ -90,7 +90,7 @@ macro_rules! echo_resp { CmdPath::new( Some($endpoint), Some(echo_cluster::ID), - Some(echo_cluster::Commands::EchoResp as u16), + Some(echo_cluster::RespCommands::EchoResp as u32), ), $data, ) diff --git a/matter/tests/common/echo_cluster.rs b/matter/tests/common/echo_cluster.rs index cf071830..dd61a0e7 100644 --- a/matter/tests/common/echo_cluster.rs +++ b/matter/tests/common/echo_cluster.rs @@ -15,31 +15,91 @@ * limitations under the License. */ -use std::sync::{Arc, Mutex, Once}; +use std::{ + convert::TryInto, + sync::{Arc, Mutex, Once}, +}; use matter::{ + attribute_enum, command_enum, data_model::objects::{ - Access, AttrDetails, AttrValue, Attribute, Cluster, ClusterType, EncodeValue, Encoder, - Quality, + Access, AttrData, AttrDataEncoder, AttrDataWriter, AttrDetails, AttrType, Attribute, + Cluster, CmdDataEncoder, CmdDataWriter, CmdDetails, Dataver, Handler, Quality, + ATTRIBUTE_LIST, FEATURE_MAP, }, error::Error, interaction_model::{ - command::CommandReq, - core::IMStatusCode, - messages::ib::{self, attr_list_write, ListOperation}, + core::Transaction, + messages::ib::{attr_list_write, ListOperation}, }, - tlv::{TLVElement, TLVWriter, TagType, ToTLV}, + tlv::{TLVElement, TagType}, + utils::rand::Rand, }; use num_derive::FromPrimitive; +use strum::{EnumDiscriminants, FromRepr}; pub const ID: u32 = 0xABCD; -#[derive(FromPrimitive)] +#[derive(FromRepr, EnumDiscriminants)] +#[repr(u16)] +pub enum Attributes { + Att1(AttrType) = 0, + Att2(AttrType) = 1, + AttWrite(AttrType) = 2, + AttCustom(AttrType) = 3, + AttWriteList(()) = 4, +} + +attribute_enum!(Attributes); + +#[derive(FromRepr)] +#[repr(u32)] pub enum Commands { EchoReq = 0x00, +} + +command_enum!(Commands); + +#[derive(FromPrimitive)] +pub enum RespCommands { EchoResp = 0x01, } +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID, + feature_map: 0, + attributes: &[ + FEATURE_MAP, + ATTRIBUTE_LIST, + Attribute::new( + AttributesDiscriminants::Att1 as u16, + Access::RV, + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::Att2 as u16, + Access::RV, + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::AttWrite as u16, + Access::WRITE.union(Access::NEED_ADMIN), + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::AttCustom as u16, + Access::READ.union(Access::NEED_VIEW), + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::AttWriteList as u16, + Access::WRITE.union(Access::NEED_ADMIN), + Quality::NONE, + ), + ], + commands: &[Commands::EchoReq as _], +}; + /// This is used in the tests to validate any settings that may have happened /// to the custom data parts of the cluster pub struct TestChecker { @@ -68,167 +128,122 @@ impl TestChecker { } pub const WRITE_LIST_MAX: usize = 5; + pub struct EchoCluster { - pub base: Cluster, + pub data_ver: Dataver, pub multiplier: u8, + pub att1: u16, + pub att2: u16, + pub att_write: u16, + pub att_custom: u32, } -#[derive(FromPrimitive)] -pub enum Attributes { - Att1 = 0, - Att2 = 1, - AttWrite = 2, - AttCustom = 3, - AttWriteList = 4, -} - -pub const ATTR_CUSTOM_VALUE: u32 = 0xcafebeef; -pub const ATTR_WRITE_DEFAULT_VALUE: u16 = 0xcafe; - -impl ClusterType for EchoCluster { - fn base(&self) -> &Cluster { - &self.base +impl EchoCluster { + pub fn new(multiplier: u8, rand: Rand) -> Self { + Self { + data_ver: Dataver::new(rand), + multiplier, + att1: 0x1234, + att2: 0x5678, + att_write: ATTR_WRITE_DEFAULT_VALUE, + att_custom: ATTR_CUSTOM_VALUE, + } } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base - } + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(mut writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + match attr.attr_id.try_into()? { + Attributes::Att1(codec) => codec.encode(writer, 0x1234), + Attributes::Att2(codec) => codec.encode(writer, 0x5678), + Attributes::AttWrite(codec) => codec.encode(writer, ATTR_WRITE_DEFAULT_VALUE), + Attributes::AttCustom(codec) => codec.encode(writer, ATTR_CUSTOM_VALUE), + Attributes::AttWriteList(_) => { + let tc_handle = TestChecker::get().unwrap(); + let tc = tc_handle.lock().unwrap(); + + writer.start_array(AttrDataWriter::TAG)?; + for i in tc.write_list.iter().flatten() { + writer.u16(TagType::Anonymous, *i)?; + } + writer.end_container()?; - fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { - match num::FromPrimitive::from_u16(attr.attr_id) { - Some(Attributes::AttCustom) => encoder.encode(EncodeValue::Closure(&|tag, tw| { - let _ = tw.u32(tag, ATTR_CUSTOM_VALUE); - })), - Some(Attributes::AttWriteList) => { - let tc_handle = TestChecker::get().unwrap(); - let tc = tc_handle.lock().unwrap(); - encoder.encode(EncodeValue::Closure(&|tag, tw| { - let _ = tw.start_array(tag); - for i in tc.write_list.iter().flatten() { - let _ = tw.u16(TagType::Anonymous, *i); + writer.complete() } - let _ = tw.end_container(); - })) + } } - _ => (), + } else { + Ok(()) } } - fn write_attribute( - &mut self, - attr: &AttrDetails, - data: &TLVElement, - ) -> Result<(), IMStatusCode> { - match num::FromPrimitive::from_u16(attr.attr_id) { - Some(Attributes::AttWriteList) => { - attr_list_write(attr, data, |op, data| self.write_attr_list(&op, data)) + pub fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + let data = data.with_dataver(self.data_ver.get())?; + + match attr.attr_id.try_into()? { + Attributes::Att1(codec) => self.att1 = codec.decode(data)?, + Attributes::Att2(codec) => self.att2 = codec.decode(data)?, + Attributes::AttWrite(codec) => self.att_write = codec.decode(data)?, + Attributes::AttCustom(codec) => self.att_custom = codec.decode(data)?, + Attributes::AttWriteList(_) => { + attr_list_write(attr, data, |op, data| self.write_attr_list(&op, data))? } - _ => self.base.write_attribute_from_tlv(attr.attr_id, data), } + + self.data_ver.changed(); + + Ok(()) } - fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { - let cmd = cmd_req - .cmd - .path - .leaf - .map(num::FromPrimitive::from_u32) - .ok_or(IMStatusCode::UnsupportedCommand)? - .ok_or(IMStatusCode::UnsupportedCommand)?; - match cmd { + pub fn invoke( + &mut self, + _transaction: &mut Transaction, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + match cmd.cmd_id.try_into()? { // This will generate an echo response on the same endpoint // with data multiplied by the multiplier Commands::EchoReq => { - let a = cmd_req.data.u8().unwrap(); - let mut echo_response = cmd_req.cmd; - echo_response.path.leaf = Some(Commands::EchoResp as u32); - - let cmd_data = |tag: TagType, t: &mut TLVWriter| { - let _ = t.start_struct(tag); - // Echo = input * self.multiplier - let _ = t.u8(TagType::Context(0), a * self.multiplier); - let _ = t.end_container(); - }; - - let invoke_resp = ib::InvResp::Cmd(ib::CmdData::new( - echo_response, - EncodeValue::Closure(&cmd_data), - )); - let _ = invoke_resp.to_tlv(cmd_req.resp, TagType::Anonymous); - cmd_req.trans.complete(); - } - _ => { - return Err(IMStatusCode::UnsupportedCommand); + let a = data.u8()?; + + let mut writer = encoder.with_command(RespCommands::EchoResp as _)?; + + writer.start_struct(CmdDataWriter::TAG)?; + // Echo = input * self.multiplier + writer.u8(TagType::Context(0), a * self.multiplier)?; + writer.end_container()?; + + writer.complete() } } - Ok(()) } -} -impl EchoCluster { - pub fn new(multiplier: u8) -> Result, Error> { - let mut c = Box::new(Self { - base: Cluster::new(ID)?, - multiplier, - }); - c.base.add_attribute(Attribute::new( - Attributes::Att1 as u16, - AttrValue::Uint16(0x1234), - Access::RV, - Quality::NONE, - ))?; - c.base.add_attribute(Attribute::new( - Attributes::Att2 as u16, - AttrValue::Uint16(0x5678), - Access::RV, - Quality::NONE, - ))?; - c.base.add_attribute(Attribute::new( - Attributes::AttWrite as u16, - AttrValue::Uint16(ATTR_WRITE_DEFAULT_VALUE), - Access::WRITE | Access::NEED_ADMIN, - Quality::NONE, - ))?; - c.base.add_attribute(Attribute::new( - Attributes::AttCustom as u16, - AttrValue::Custom, - Access::READ | Access::NEED_VIEW, - Quality::NONE, - ))?; - c.base.add_attribute(Attribute::new( - Attributes::AttWriteList as u16, - AttrValue::Custom, - Access::WRITE | Access::NEED_ADMIN, - Quality::NONE, - ))?; - Ok(c) - } - - fn write_attr_list( - &mut self, - op: &ListOperation, - data: &TLVElement, - ) -> Result<(), IMStatusCode> { + fn write_attr_list(&mut self, op: &ListOperation, data: &TLVElement) -> Result<(), Error> { let tc_handle = TestChecker::get().unwrap(); let mut tc = tc_handle.lock().unwrap(); match op { ListOperation::AddItem => { - let data = data.u16().map_err(|_| IMStatusCode::Failure)?; + let data = data.u16()?; for i in 0..WRITE_LIST_MAX { if tc.write_list[i].is_none() { tc.write_list[i] = Some(data); return Ok(()); } } - Err(IMStatusCode::ResourceExhausted) + + Err(Error::ResourceExhausted) } ListOperation::EditItem(index) => { - let data = data.u16().map_err(|_| IMStatusCode::Failure)?; + let data = data.u16()?; if tc.write_list[*index as usize].is_some() { tc.write_list[*index as usize] = Some(data); Ok(()) } else { - Err(IMStatusCode::InvalidAction) + Err(Error::InvalidAction) } } ListOperation::DeleteItem(index) => { @@ -236,7 +251,7 @@ impl EchoCluster { tc.write_list[*index as usize] = None; Ok(()) } else { - Err(IMStatusCode::InvalidAction) + Err(Error::InvalidAction) } } ListOperation::DeleteList => { @@ -248,3 +263,26 @@ impl EchoCluster { } } } + +pub const ATTR_CUSTOM_VALUE: u32 = 0xcafebeef; +pub const ATTR_WRITE_DEFAULT_VALUE: u16 = 0xcafe; + +impl Handler for EchoCluster { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + EchoCluster::read(self, attr, encoder) + } + + fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + EchoCluster::write(self, attr, data) + } + + fn invoke( + &mut self, + transaction: &mut Transaction, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + EchoCluster::invoke(self, transaction, cmd, data, encoder) + } +} diff --git a/matter/tests/common/handlers.rs b/matter/tests/common/handlers.rs new file mode 100644 index 00000000..7235b8a3 --- /dev/null +++ b/matter/tests/common/handlers.rs @@ -0,0 +1,317 @@ +use core::time; +use std::thread; + +use log::{info, warn}; +use matter::{ + interaction_model::{ + core::{IMStatusCode, OpCode}, + messages::{ + ib::{AttrData, AttrPath, AttrResp, AttrStatus, CmdData, DataVersionFilter}, + msg::{ + self, InvReq, ReadReq, ReportDataMsg, StatusResp, TimedReq, WriteReq, WriteResp, + WriteRespTag, + }, + }, + }, + tlv::{self, FromTLV, TLVArray, ToTLV}, + transport::{ + exchange::{self, Exchange}, + session::NocCatIds, + }, + Matter, +}; + +use super::{ + attributes::assert_attr_report, + commands::{assert_inv_response, ExpectedInvResp}, + im_engine::{ImEngine, ImInput, IM_ENGINE_PEER_ID}, +}; + +pub enum WriteResponse<'a> { + TransactionError, + TransactionSuccess(&'a [AttrStatus]), +} + +pub enum TimedInvResponse<'a> { + TransactionError(IMStatusCode), + TransactionSuccess(&'a [ExpectedInvResp]), +} + +impl<'a> ImEngine<'a> { + // Helper for handling Read Req sequences for this file + pub fn handle_read_reqs( + &mut self, + peer_node_id: u64, + input: &[AttrPath], + expected: &[AttrResp], + ) { + let mut out_buf = [0u8; 400]; + let received = self.gen_read_reqs_output(peer_node_id, input, None, &mut out_buf); + assert_attr_report(&received, expected) + } + + pub fn new_with_read_reqs( + matter: &'a Matter<'a>, + input: &[AttrPath], + expected: &[AttrResp], + ) -> Self { + let mut im = Self::new(matter); + + let mut out_buf = [0u8; 400]; + let received = im.gen_read_reqs_output(IM_ENGINE_PEER_ID, input, None, &mut out_buf); + assert_attr_report(&received, expected); + + im + } + + pub fn gen_read_reqs_output<'b>( + &mut self, + peer_node_id: u64, + input: &[AttrPath], + dataver_filters: Option>, + out_buf: &'b mut [u8], + ) -> ReportDataMsg<'b> { + let mut read_req = ReadReq::new(true).set_attr_requests(input); + read_req.dataver_filters = dataver_filters; + + let mut input = ImInput::new(OpCode::ReadRequest, &read_req); + input.set_peer_node_id(peer_node_id); + + let (_, out_buf) = self.process(&input, out_buf); + + tlv::print_tlv_list(out_buf); + let root = tlv::get_root_node_struct(out_buf).unwrap(); + ReportDataMsg::from_tlv(&root).unwrap() + } + + pub fn handle_write_reqs( + &mut self, + peer_node_id: u64, + peer_cat_ids: Option<&NocCatIds>, + input: &[AttrData], + expected: &[AttrStatus], + ) { + let mut out_buf = [0u8; 400]; + let write_req = WriteReq::new(false, input); + + let mut input = ImInput::new(OpCode::WriteRequest, &write_req); + input.set_peer_node_id(peer_node_id); + if let Some(cat_ids) = peer_cat_ids { + input.set_cat_ids(cat_ids); + } + + let (_, out_buf) = self.process(&input, &mut out_buf); + + tlv::print_tlv_list(out_buf); + let root = tlv::get_root_node_struct(out_buf).unwrap(); + + let mut index = 0; + let response_iter = root + .find_tag(WriteRespTag::WriteResponses as u32) + .unwrap() + .confirm_array() + .unwrap() + .enter() + .unwrap(); + + for response in response_iter { + info!("Validating index {}", index); + let status = AttrStatus::from_tlv(&response).unwrap(); + assert_eq!(expected[index], status); + info!("Index {} success", index); + index += 1; + } + assert_eq!(index, expected.len()); + } + + pub fn new_with_write_reqs( + matter: &'a Matter<'a>, + input: &[AttrData], + expected: &[AttrStatus], + ) -> Self { + let mut im = Self::new(matter); + + im.handle_write_reqs(IM_ENGINE_PEER_ID, None, input, expected); + + im + } + + // Helper for handling Invoke Command sequences + pub fn handle_commands( + &mut self, + peer_node_id: u64, + input: &[CmdData], + expected: &[ExpectedInvResp], + ) { + let mut out_buf = [0u8; 400]; + let req = InvReq { + suppress_response: Some(false), + timed_request: Some(false), + inv_requests: Some(TLVArray::Slice(input)), + }; + + let mut input = ImInput::new(OpCode::InvokeRequest, &req); + input.set_peer_node_id(peer_node_id); + + let (_, out_buf) = self.process(&input, &mut out_buf); + tlv::print_tlv_list(out_buf); + let root = tlv::get_root_node_struct(out_buf).unwrap(); + let resp = msg::InvResp::from_tlv(&root).unwrap(); + assert_inv_response(&resp, expected) + } + + pub fn new_with_commands( + matter: &'a Matter<'a>, + input: &[CmdData], + expected: &[ExpectedInvResp], + ) -> Self { + let mut im = ImEngine::new(matter); + + im.handle_commands(IM_ENGINE_PEER_ID, input, expected); + + im + } + + fn handle_timed_reqs<'b>( + &mut self, + opcode: OpCode, + request: &dyn ToTLV, + timeout: u16, + delay: u16, + output: &'b mut [u8], + ) -> (u8, &'b [u8]) { + // Use the same exchange for all parts of the transaction + self.exch = Some(Exchange::new(1, 0, exchange::Role::Responder)); + + if timeout != 0 { + // Send Timed Req + let mut tmp_buf = [0u8; 400]; + let timed_req = TimedReq { timeout }; + let im_input = ImInput::new(OpCode::TimedRequest, &timed_req); + let (_, out_buf) = self.process(&im_input, &mut tmp_buf); + tlv::print_tlv_list(out_buf); + } else { + warn!("Skipping timed request"); + } + + // Process any delays + let delay = time::Duration::from_millis(delay.into()); + thread::sleep(delay); + + // Send Write Req + let input = ImInput::new(opcode, request); + let (resp_opcode, output) = self.process(&input, output); + (resp_opcode, output) + } + + // Helper for handling Write Attribute sequences + pub fn handle_timed_write_reqs( + &mut self, + input: &[AttrData], + expected: &WriteResponse, + timeout: u16, + delay: u16, + ) { + let mut out_buf = [0u8; 400]; + let write_req = WriteReq::new(false, input); + + let (resp_opcode, out_buf) = self.handle_timed_reqs( + OpCode::WriteRequest, + &write_req, + timeout, + delay, + &mut out_buf, + ); + tlv::print_tlv_list(out_buf); + let root = tlv::get_root_node_struct(out_buf).unwrap(); + + match expected { + WriteResponse::TransactionSuccess(t) => { + assert_eq!( + num::FromPrimitive::from_u8(resp_opcode), + Some(OpCode::WriteResponse) + ); + let resp = WriteResp::from_tlv(&root).unwrap(); + assert_eq!(resp.write_responses, t); + } + WriteResponse::TransactionError => { + assert_eq!( + num::FromPrimitive::from_u8(resp_opcode), + Some(OpCode::StatusResponse) + ); + let status_resp = StatusResp::from_tlv(&root).unwrap(); + assert_eq!(status_resp.status, IMStatusCode::Timeout); + } + } + } + + pub fn new_with_timed_write_reqs( + matter: &'a Matter<'a>, + input: &[AttrData], + expected: &WriteResponse, + timeout: u16, + delay: u16, + ) -> Self { + let mut im = ImEngine::new(matter); + + im.handle_timed_write_reqs(input, expected, timeout, delay); + + im + } + + // Helper for handling Invoke Command sequences + pub fn handle_timed_commands( + &mut self, + input: &[CmdData], + expected: &TimedInvResponse, + timeout: u16, + delay: u16, + set_timed_request: bool, + ) { + let mut out_buf = [0u8; 400]; + let req = InvReq { + suppress_response: Some(false), + timed_request: Some(set_timed_request), + inv_requests: Some(TLVArray::Slice(input)), + }; + + let (resp_opcode, out_buf) = + self.handle_timed_reqs(OpCode::InvokeRequest, &req, timeout, delay, &mut out_buf); + tlv::print_tlv_list(out_buf); + let root = tlv::get_root_node_struct(out_buf).unwrap(); + + match expected { + TimedInvResponse::TransactionSuccess(t) => { + assert_eq!( + num::FromPrimitive::from_u8(resp_opcode), + Some(OpCode::InvokeResponse) + ); + let resp = msg::InvResp::from_tlv(&root).unwrap(); + assert_inv_response(&resp, t) + } + TimedInvResponse::TransactionError(e) => { + assert_eq!( + num::FromPrimitive::from_u8(resp_opcode), + Some(OpCode::StatusResponse) + ); + let status_resp = StatusResp::from_tlv(&root).unwrap(); + assert_eq!(status_resp.status, *e); + } + } + } + + pub fn new_with_timed_commands( + matter: &'a Matter<'a>, + input: &[CmdData], + expected: &TimedInvResponse, + timeout: u16, + delay: u16, + set_timed_request: bool, + ) -> Self { + let mut im = ImEngine::new(matter); + + im.handle_timed_commands(input, expected, timeout, delay, set_timed_request); + + im + } +} diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index f91433c6..348ce74a 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -16,54 +16,60 @@ */ use crate::common::echo_cluster; -use boxslab::Slab; +use core::borrow::Borrow; use matter::{ - acl::{AclEntry, AclMgr, AuthMode}, + acl::{AclEntry, AuthMode}, data_model::{ - cluster_basic_information::BasicInfoConfig, + cluster_basic_information::{self, BasicInfoConfig}, + cluster_on_off::{self, OnOffCluster}, core::DataModel, - device_types::device_type_add_on_off_light, - objects::Privilege, - sdm::dev_att::{DataType, DevAttDataFetcher}, + device_types::{DEV_TYPE_ON_OFF_LIGHT, DEV_TYPE_ROOT_NODE}, + objects::{ChainedHandler, Endpoint, Node, Privilege}, + root_endpoint::{self, RootEndpointHandler}, + sdm::{ + admin_commissioning, + dev_att::{DataType, DevAttDataFetcher}, + general_commissioning, noc, nw_commissioning, + }, + system_model::access_control, }, error::Error, - fabric::FabricMgr, - interaction_model::{core::OpCode, InteractionModel}, - secure_channel::pake::PaseMgr, + interaction_model::core::{InteractionModel, OpCode}, + mdns::Mdns, tlv::{TLVWriter, TagType, ToTLV}, transport::packet::Packet, transport::{ exchange::{self, Exchange, ExchangeCtx}, network::Address, - packet::PacketPool, - proto_demux::ProtoCtx, - session::{CloneData, NocCatIds, SessionMgr, SessionMode}, + proto_ctx::ProtoCtx, + session::{CaseDetails, CloneData, NocCatIds, SessionMgr, SessionMode}, }, - transport::{proto_demux::HandleProto, session::CaseDetails}, - utils::writebuf::WriteBuf, + utils::{epoch::sys_epoch, rand::dummy_rand, writebuf::WriteBuf}, + Matter, }; -use std::{ - net::{Ipv4Addr, SocketAddr}, - sync::Arc, +use std::net::{Ipv4Addr, SocketAddr}; + +use super::echo_cluster::EchoCluster; + +const BASIC_INFO: BasicInfoConfig<'static> = BasicInfoConfig { + vid: 10, + pid: 11, + hw_ver: 12, + sw_ver: 13, + sw_ver_str: "13", + serial_no: "aabbccdd", + device_name: "Test Device", }; pub struct DummyDevAtt {} + impl DevAttDataFetcher for DummyDevAtt { fn get_devatt_data(&self, _data_type: DataType, _data: &mut [u8]) -> Result { Ok(2) } } -/// An Interaction Model Engine to facilitate easy testing -pub struct ImEngine { - pub dm: DataModel, - pub acl_mgr: Arc, - pub im: Box, - // By default, a new exchange is created for every run, if you wish to instead using a specific - // exchange, set this variable. This is helpful in situations where you have to run multiple - // actions in the same transaction (exchange) - pub exch: Option, -} +pub const IM_ENGINE_PEER_ID: u64 = 445566; pub struct ImInput<'a> { action: OpCode, @@ -72,7 +78,6 @@ pub struct ImInput<'a> { cat_ids: NocCatIds, } -pub const IM_ENGINE_PEER_ID: u64 = 445566; impl<'a> ImInput<'a> { pub fn new(action: OpCode, data: &'a dyn ToTLV) -> Self { Self { @@ -92,56 +97,86 @@ impl<'a> ImInput<'a> { } } -impl ImEngine { - /// Create the interaction model engine - pub fn new() -> Self { - let dev_det = BasicInfoConfig { - vid: 10, - pid: 11, - hw_ver: 12, - sw_ver: 13, - sw_ver_str: "13".to_string(), - serial_no: "aabbccdd".to_string(), - device_name: "Test Device".to_string(), - }; +pub type DmHandler<'a> = ChainedHandler< + OnOffCluster, + ChainedHandler>>, +>; + +pub fn matter<'a>(mdns: &'a mut dyn Mdns) -> Matter<'_> { + Matter::new(&BASIC_INFO, mdns, sys_epoch, dummy_rand) +} + +/// An Interaction Model Engine to facilitate easy testing +pub struct ImEngine<'a> { + pub matter: &'a Matter<'a>, + pub im: InteractionModel>>, + // By default, a new exchange is created for every run, if you wish to instead using a specific + // exchange, set this variable. This is helpful in situations where you have to run multiple + // actions in the same transaction (exchange) + pub exch: Option, +} - let dev_att = Box::new(DummyDevAtt {}); - let fabric_mgr = Arc::new(FabricMgr::new().unwrap()); - let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); - let pase_mgr = PaseMgr::new(); - acl_mgr.erase_all(); +impl<'a> ImEngine<'a> { + /// Create the interaction model engine + pub fn new(matter: &'a Matter<'a>) -> Self { let mut default_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); // Only allow the standard peer node id of the IM Engine default_acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); - acl_mgr.add(default_acl).unwrap(); - let dm = DataModel::new(dev_det, dev_att, fabric_mgr, acl_mgr.clone(), pase_mgr).unwrap(); - - { - let mut d = dm.node.write().unwrap(); - let light_endpoint = device_type_add_on_off_light(&mut d).unwrap(); - d.add_cluster(0, echo_cluster::EchoCluster::new(2).unwrap()) - .unwrap(); - d.add_cluster(light_endpoint, echo_cluster::EchoCluster::new(3).unwrap()) - .unwrap(); - } - - let im = Box::new(InteractionModel::new(Box::new(dm.clone()))); + matter.acl_mgr.borrow_mut().add(default_acl).unwrap(); + + let dm = DataModel::new( + matter.borrow(), + &Node { + id: 0, + endpoints: &[ + Endpoint { + id: 0, + clusters: &[ + cluster_basic_information::CLUSTER, + general_commissioning::CLUSTER, + nw_commissioning::CLUSTER, + admin_commissioning::CLUSTER, + noc::CLUSTER, + access_control::CLUSTER, + echo_cluster::CLUSTER, + ], + device_type: DEV_TYPE_ROOT_NODE, + }, + Endpoint { + id: 1, + clusters: &[echo_cluster::CLUSTER, cluster_on_off::CLUSTER], + device_type: DEV_TYPE_ON_OFF_LIGHT, + }, + ], + }, + root_endpoint::handler(0, &DummyDevAtt {}, matter) + .chain(0, echo_cluster::ID, EchoCluster::new(2, *matter.borrow())) + .chain(1, echo_cluster::ID, EchoCluster::new(3, *matter.borrow())) + .chain(1, cluster_on_off::ID, OnOffCluster::new(*matter.borrow())), + ); Self { - dm, - acl_mgr, - im, + matter, + im: InteractionModel(dm), exch: None, } } + pub fn echo_cluster(&self, endpoint: u16) -> &EchoCluster { + match endpoint { + 0 => &self.im.0.handler.next.next.handler, + 1 => &self.im.0.handler.next.handler, + _ => panic!(), + } + } + /// Run a transaction through the interaction model engine - pub fn process<'a>(&mut self, input: &ImInput, data_out: &'a mut [u8]) -> (u8, &'a mut [u8]) { + pub fn process<'b>(&mut self, input: &ImInput, data_out: &'b mut [u8]) -> (u8, &'b [u8]) { let mut new_exch = Exchange::new(1, 0, exchange::Role::Responder); // Choose whether to use a new exchange, or use the one from the ImEngine configuration let exch = self.exch.as_mut().unwrap_or(&mut new_exch); - let mut sess_mgr: SessionMgr = Default::default(); + let mut sess_mgr = SessionMgr::new(*self.matter.borrow(), *self.matter.borrow()); let clone_data = CloneData::new( 123456, @@ -156,9 +191,15 @@ impl ImEngine { ); let sess_idx = sess_mgr.clone_session(&clone_data).unwrap(); let sess = sess_mgr.get_session_handle(sess_idx); - let exch_ctx = ExchangeCtx { exch, sess }; - let mut rx = Slab::::try_new(Packet::new_rx().unwrap()).unwrap(); - let tx = Slab::::try_new(Packet::new_tx().unwrap()).unwrap(); + let exch_ctx = ExchangeCtx { + exch, + sess, + epoch: *self.matter.borrow(), + }; + let mut tx_buf = [0; 1500]; + let mut rx_buf = [0; 1500]; + let mut rx = Packet::new_rx(&mut rx_buf); + let mut tx = Packet::new_tx(&mut tx_buf); // Create fake rx packet rx.set_proto_id(0x01); rx.set_proto_opcode(input.action as u8); @@ -166,36 +207,37 @@ impl ImEngine { { let mut buf = [0u8; 400]; - let buf_len = buf.len(); - let mut wb = WriteBuf::new(&mut buf, buf_len); + let mut wb = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut wb); input.data.to_tlv(&mut tw, TagType::Anonymous).unwrap(); - let input_data = wb.as_borrow_slice(); + let input_data = wb.as_slice(); let in_data_len = input_data.len(); - let rx_buf = rx.as_borrow_slice(); + let rx_buf = rx.as_mut_slice(); rx_buf[..in_data_len].copy_from_slice(input_data); rx.get_parsebuf().unwrap().set_len(in_data_len); } - let mut ctx = ProtoCtx::new(exch_ctx, rx, tx); - self.im.handle_proto_id(&mut ctx).unwrap(); - let out_data_len = ctx.tx.as_borrow_slice().len(); - data_out[..out_data_len].copy_from_slice(ctx.tx.as_borrow_slice()); + let mut ctx = ProtoCtx::new(exch_ctx, &rx, &mut tx); + self.im.handle(&mut ctx).unwrap(); + let out_data_len = ctx.tx.as_slice().len(); + data_out[..out_data_len].copy_from_slice(ctx.tx.as_slice()); let response = ctx.tx.get_proto_opcode(); - (response, &mut data_out[..out_data_len]) + (response, &data_out[..out_data_len]) } } -// Create an Interaction Model, Data Model and run a rx/tx transaction through it -pub fn im_engine<'a>( - action: OpCode, - data: &dyn ToTLV, - data_out: &'a mut [u8], -) -> (DataModel, u8, &'a mut [u8]) { - let mut engine = ImEngine::new(); - let input = ImInput::new(action, data); - let (response, output) = engine.process(&input, data_out); - (engine.dm, response, output) -} +// TODO - Remove? +// // Create an Interaction Model, Data Model and run a rx/tx transaction through it +// pub fn im_engine<'a>( +// matter: &'a Matter, +// action: OpCode, +// data: &dyn ToTLV, +// data_out: &'a mut [u8], +// ) -> (DmHandler<'a>, u8, &'a mut [u8]) { +// let mut engine = ImEngine::new(matter); +// let input = ImInput::new(action, data); +// let (response, output) = engine.process(&input, data_out); +// (engine.dm.handler, response, output) +// } diff --git a/matter/tests/common/mod.rs b/matter/tests/common/mod.rs index dea136a0..0d2cc9c8 100644 --- a/matter/tests/common/mod.rs +++ b/matter/tests/common/mod.rs @@ -18,4 +18,5 @@ pub mod attributes; pub mod commands; pub mod echo_cluster; +pub mod handlers; pub mod im_engine; diff --git a/matter/tests/data_model/acl_and_dataver.rs b/matter/tests/data_model/acl_and_dataver.rs index 493a282a..535555ba 100644 --- a/matter/tests/data_model/acl_and_dataver.rs +++ b/matter/tests/data_model/acl_and_dataver.rs @@ -18,19 +18,16 @@ use matter::{ acl::{gen_noc_cat, AclEntry, AuthMode, Target}, data_model::{ - objects::{AttrValue, EncodeValue, Privilege}, + objects::{EncodeValue, Privilege}, system_model::access_control, }, interaction_model::{ - core::{IMStatusCode, OpCode}, - messages::{ - ib::{AttrData, AttrPath, AttrResp, AttrStatus, ClusterPath, DataVersionFilter}, - msg::{ReadReq, ReportDataMsg, WriteReq}, - }, - messages::{msg, GenericPath}, + core::IMStatusCode, + messages::ib::{AttrData, AttrPath, AttrResp, AttrStatus, ClusterPath, DataVersionFilter}, + messages::GenericPath, }, - tlv::{self, ElementType, FromTLV, TLVArray, TLVElement, TLVWriter, TagType}, - transport::session::NocCatIds, + mdns::DummyMdns, + tlv::{ElementType, TLVArray, TLVElement, TLVWriter, TagType}, }; use crate::{ @@ -38,81 +35,10 @@ use crate::{ common::{ attributes::*, echo_cluster::{self, ATTR_WRITE_DEFAULT_VALUE}, - im_engine::{ImEngine, ImInput}, + im_engine::{matter, ImEngine}, }, }; -// Helper for handling Read Req sequences for this file -fn handle_read_reqs( - im: &mut ImEngine, - peer_node_id: u64, - input: &[AttrPath], - expected: &[AttrResp], -) { - let mut out_buf = [0u8; 400]; - let received = gen_read_reqs_output(im, peer_node_id, input, None, &mut out_buf); - assert_attr_report(&received, expected) -} - -fn gen_read_reqs_output<'a>( - im: &mut ImEngine, - peer_node_id: u64, - input: &[AttrPath], - dataver_filters: Option>, - out_buf: &'a mut [u8], -) -> ReportDataMsg<'a> { - let mut read_req = ReadReq::new(true).set_attr_requests(input); - read_req.dataver_filters = dataver_filters; - - let mut input = ImInput::new(OpCode::ReadRequest, &read_req); - input.set_peer_node_id(peer_node_id); - - let (_, out_buf) = im.process(&input, out_buf); - - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); - ReportDataMsg::from_tlv(&root).unwrap() -} - -// Helper for handling Write Attribute sequences -fn handle_write_reqs( - im: &mut ImEngine, - peer_node_id: u64, - peer_cat_ids: Option<&NocCatIds>, - input: &[AttrData], - expected: &[AttrStatus], -) { - let mut out_buf = [0u8; 400]; - let write_req = WriteReq::new(false, input); - - let mut input = ImInput::new(OpCode::WriteRequest, &write_req); - input.set_peer_node_id(peer_node_id); - if let Some(cat_ids) = peer_cat_ids { - input.set_cat_ids(cat_ids); - } - let (_, out_buf) = im.process(&input, &mut out_buf); - - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); - - let mut index = 0; - let response_iter = root - .find_tag(msg::WriteRespTag::WriteResponses as u32) - .unwrap() - .confirm_array() - .unwrap() - .enter() - .unwrap(); - for response in response_iter { - println!("Validating index {}", index); - let status = AttrStatus::from_tlv(&response).unwrap(); - assert_eq!(expected[index], status); - println!("Index {} success", index); - index += 1; - } - assert_eq!(index, expected.len()); -} - #[test] /// Ensure that wildcard read attributes don't include error response /// and silently drop the data when access is not granted @@ -122,43 +48,45 @@ fn wc_read_attribute() { let wc_att1 = GenericPath::new( None, Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let ep0_att1 = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let ep1_att1 = GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let peer = 98765; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns {}; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); // Test1: Empty Response as no ACL matches let input = &[AttrPath::new(&wc_att1)]; let expected = &[]; - handle_read_reqs(&mut im, peer, input, expected); + im.handle_read_reqs(peer, input, expected); // Add ACL to allow our peer to only access endpoint 0 let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(peer).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test2: Only Single response as only single endpoint is allowed let input = &[AttrPath::new(&wc_att1)]; let expected = &[attr_data_path!(ep0_att1, ElementType::U16(0x1234))]; - handle_read_reqs(&mut im, peer, input, expected); + im.handle_read_reqs(peer, input, expected); - // Add ACL to allow our peer to only access endpoint 1 + // Add ACL to allow our peer to also access endpoint 1 let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(peer).unwrap(); acl.add_target(Target::new(Some(1), None, None)).unwrap(); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test3: Both responses are valid let input = &[AttrPath::new(&wc_att1)]; @@ -166,7 +94,7 @@ fn wc_read_attribute() { attr_data_path!(ep0_att1, ElementType::U16(0x1234)), attr_data_path!(ep1_att1, ElementType::U16(0x1234)), ]; - handle_read_reqs(&mut im, peer, input, expected); + im.handle_read_reqs(peer, input, expected); } #[test] @@ -178,48 +106,33 @@ fn exact_read_attribute() { let wc_att1 = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let ep0_att1 = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let peer = 98765; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns {}; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); // Test1: Unsupported Access error as no ACL matches let input = &[AttrPath::new(&wc_att1)]; let expected = &[attr_status!(&ep0_att1, IMStatusCode::UnsupportedAccess)]; - handle_read_reqs(&mut im, peer, input, expected); + im.handle_read_reqs(peer, input, expected); // Add ACL to allow our peer to access any endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(peer).unwrap(); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test2: Only Single response as only single endpoint is allowed let input = &[AttrPath::new(&wc_att1)]; let expected = &[attr_data_path!(ep0_att1, ElementType::U16(0x1234))]; - handle_read_reqs(&mut im, peer, input, expected); -} - -fn read_cluster_id_write_attr(im: &ImEngine, endpoint: u16) -> AttrValue { - let node = im.dm.node.read().unwrap(); - let echo = node.get_cluster(endpoint, echo_cluster::ID).unwrap(); - - echo.base() - .read_attribute_raw(echo_cluster::Attributes::AttWrite as u16) - .unwrap() - .clone() -} - -fn read_cluster_id_data_ver(im: &ImEngine, endpoint: u16) -> u32 { - let node = im.dm.node.read().unwrap(); - let echo = node.get_cluster(endpoint, echo_cluster::ID).unwrap(); - - echo.base().get_dataver() + im.handle_read_reqs(peer, input, expected); } #[test] @@ -239,17 +152,17 @@ fn wc_write_attribute() { let wc_att = GenericPath::new( None, Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let ep1_att = GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let input0 = &[AttrData::new( @@ -264,54 +177,41 @@ fn wc_write_attribute() { )]; let peer = 98765; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns {}; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); // Test 1: Wildcard write to an attribute without permission should return // no error - handle_write_reqs(&mut im, peer, None, input0, &[]); - { - let node = im.dm.node.read().unwrap(); - let echo = node.get_cluster(0, echo_cluster::ID).unwrap(); - assert_eq!( - AttrValue::Uint16(ATTR_WRITE_DEFAULT_VALUE), - *echo - .base() - .read_attribute_raw(echo_cluster::Attributes::AttWrite as u16) - .unwrap() - ); - } + im.handle_write_reqs(peer, None, input0, &[]); + assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); // Add ACL to allow our peer to access one endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(peer).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test 2: Wildcard write to attributes will only return attributes // where the writes were successful - handle_write_reqs( - &mut im, + im.handle_write_reqs( peer, None, input0, &[AttrStatus::new(&ep0_att, IMStatusCode::Success, 0)], ); - assert_eq!(AttrValue::Uint16(val0), read_cluster_id_write_attr(&im, 0)); - assert_eq!( - AttrValue::Uint16(ATTR_WRITE_DEFAULT_VALUE), - read_cluster_id_write_attr(&im, 1) - ); + assert_eq!(val0, im.echo_cluster(0).att_write); + assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(1).att_write); // Add ACL to allow our peer to access another endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(peer).unwrap(); acl.add_target(Target::new(Some(1), None, None)).unwrap(); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test 3: Wildcard write to attributes will return multiple attributes // where the writes were successful - handle_write_reqs( - &mut im, + im.handle_write_reqs( peer, None, input1, @@ -320,8 +220,8 @@ fn wc_write_attribute() { AttrStatus::new(&ep1_att, IMStatusCode::Success, 0), ], ); - assert_eq!(AttrValue::Uint16(val1), read_cluster_id_write_attr(&im, 0)); - assert_eq!(AttrValue::Uint16(val1), read_cluster_id_write_attr(&im, 1)); + assert_eq!(val1, im.echo_cluster(0).att_write); + assert_eq!(val1, im.echo_cluster(1).att_write); } #[test] @@ -337,7 +237,7 @@ fn exact_write_attribute() { let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let input = &[AttrData::new( @@ -353,25 +253,24 @@ fn exact_write_attribute() { let expected_success = &[AttrStatus::new(&ep0_att, IMStatusCode::Success, 0)]; let peer = 98765; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns {}; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); // Test 1: Exact write to an attribute without permission should return // Unsupported Access Error - handle_write_reqs(&mut im, peer, None, input, expected_fail); - assert_eq!( - AttrValue::Uint16(ATTR_WRITE_DEFAULT_VALUE), - read_cluster_id_write_attr(&im, 0) - ); + im.handle_write_reqs(peer, None, input, expected_fail); + assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); // Add ACL to allow our peer to access any endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(peer).unwrap(); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test 1: Exact write to an attribute with permission should grant // access - handle_write_reqs(&mut im, peer, None, input, expected_success); - assert_eq!(AttrValue::Uint16(val0), read_cluster_id_write_attr(&im, 0)); + im.handle_write_reqs(peer, None, input, expected_success); + assert_eq!(val0, im.echo_cluster(0).att_write); } #[test] @@ -388,7 +287,7 @@ fn exact_write_attribute_noc_cat() { let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let input = &[AttrData::new( @@ -408,25 +307,24 @@ fn exact_write_attribute_noc_cat() { let noc_cat = gen_noc_cat(0xABCD, 2); let cat_in_acl = gen_noc_cat(0xABCD, 1); let cat_ids = [noc_cat, 0, 0]; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); // Test 1: Exact write to an attribute without permission should return // Unsupported Access Error - handle_write_reqs(&mut im, peer, Some(&cat_ids), input, expected_fail); - assert_eq!( - AttrValue::Uint16(ATTR_WRITE_DEFAULT_VALUE), - read_cluster_id_write_attr(&im, 0) - ); + im.handle_write_reqs(peer, Some(&cat_ids), input, expected_fail); + assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); // Add ACL to allow our peer to access any endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); acl.add_subject_catid(cat_in_acl).unwrap(); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test 1: Exact write to an attribute with permission should grant // access - handle_write_reqs(&mut im, peer, Some(&cat_ids), input, expected_success); - assert_eq!(AttrValue::Uint16(val0), read_cluster_id_write_attr(&im, 0)); + im.handle_write_reqs(peer, Some(&cat_ids), input, expected_success); + assert_eq!(val0, im.echo_cluster(0).att_write); } #[test] @@ -440,7 +338,7 @@ fn insufficient_perms_write() { let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let input0 = &[AttrData::new( None, @@ -449,17 +347,18 @@ fn insufficient_perms_write() { )]; let peer = 98765; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns {}; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); // Add ACL to allow our peer with only OPERATE permission let mut acl = AclEntry::new(1, Privilege::OPERATE, AuthMode::Case); acl.add_subject(peer).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test: Not enough permission should return error - handle_write_reqs( - &mut im, + im.handle_write_reqs( peer, None, input0, @@ -469,10 +368,7 @@ fn insufficient_perms_write() { 0, )], ); - assert_eq!( - AttrValue::Uint16(ATTR_WRITE_DEFAULT_VALUE), - read_cluster_id_write_attr(&im, 0) - ); + assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); } #[test] @@ -485,7 +381,9 @@ fn insufficient_perms_write() { fn write_with_runtime_acl_add() { let _ = env_logger::try_init(); let peer = 98765; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns {}; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); let val0 = 10; let attr_data0 = |tag, t: &mut TLVWriter| { @@ -494,7 +392,7 @@ fn write_with_runtime_acl_add() { let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let input0 = AttrData::new( None, @@ -509,7 +407,7 @@ fn write_with_runtime_acl_add() { let acl_att = GenericPath::new( Some(0), Some(access_control::ID), - Some(access_control::Attributes::Acl as u32), + Some(access_control::AttributesDiscriminants::Acl as u32), ); let acl_input = AttrData::new( None, @@ -523,11 +421,10 @@ fn write_with_runtime_acl_add() { basic_acl .add_target(Target::new(Some(0), Some(access_control::ID), None)) .unwrap(); - im.acl_mgr.add(basic_acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(basic_acl).unwrap(); // Test: deny write (with error), then ACL is added, then allow write - handle_write_reqs( - &mut im, + im.handle_write_reqs( peer, None, // write to echo-cluster attribute, write to acl attribute, write to echo-cluster attribute @@ -538,7 +435,7 @@ fn write_with_runtime_acl_add() { AttrStatus::new(&ep0_att, IMStatusCode::Success, 0), ], ); - assert_eq!(AttrValue::Uint16(val0), read_cluster_id_write_attr(&im, 0)); + assert_eq!(val0, im.echo_cluster(0).att_write); } #[test] @@ -551,16 +448,18 @@ fn test_read_data_ver() { // - 2 responses are expected let _ = env_logger::try_init(); let peer = 98765; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns {}; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); // Add ACL to allow our peer with only OPERATE permission let acl = AclEntry::new(1, Privilege::OPERATE, AuthMode::Case); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); let wc_ep_att1 = GenericPath::new( None, Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let input = &[AttrPath::new(&wc_ep_att1)]; @@ -569,7 +468,7 @@ fn test_read_data_ver() { GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32) + Some(echo_cluster::AttributesDiscriminants::Att1 as u32) ), ElementType::U16(0x1234) ), @@ -577,7 +476,7 @@ fn test_read_data_ver() { GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32) + Some(echo_cluster::AttributesDiscriminants::Att1 as u32) ), ElementType::U16(0x1234) ), @@ -585,7 +484,7 @@ fn test_read_data_ver() { let mut out_buf = [0u8; 400]; // Test 1: Simple read to retrieve the current Data Version of Cluster at Endpoint 0 - let received = gen_read_reqs_output(&mut im, peer, input, None, &mut out_buf); + let received = im.gen_read_reqs_output(peer, input, None, &mut out_buf); assert_attr_report(&received, expected); let data_ver_cluster_at_0 = received @@ -607,8 +506,7 @@ fn test_read_data_ver() { }]; // Test 2: Add Dataversion filter for cluster at endpoint 0 only single entry should be retrieved - let received = gen_read_reqs_output( - &mut im, + let received = im.gen_read_reqs_output( peer, input, Some(TLVArray::Slice(&dataver_filter)), @@ -618,7 +516,7 @@ fn test_read_data_ver() { GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32) + Some(echo_cluster::AttributesDiscriminants::Att1 as u32) ), ElementType::U16(0x1234) )]; @@ -629,11 +527,10 @@ fn test_read_data_ver() { let ep0_att1 = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let input = &[AttrPath::new(&ep0_att1)]; - let received = gen_read_reqs_output( - &mut im, + let received = im.gen_read_reqs_output( peer, input, Some(TLVArray::Slice(&dataver_filter)), @@ -654,21 +551,23 @@ fn test_write_data_ver() { // - 2 responses are expected let _ = env_logger::try_init(); let peer = 98765; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns {}; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); // Add ACL to allow our peer with only OPERATE permission let acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); let wc_ep_attwrite = GenericPath::new( None, Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let ep0_attwrite = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let val0 = 10u16; @@ -676,7 +575,7 @@ fn test_write_data_ver() { let attr_data0 = EncodeValue::Value(&val0); let attr_data1 = EncodeValue::Value(&val1); - let initial_data_ver = read_cluster_id_data_ver(&im, 0); + let initial_data_ver = im.echo_cluster(0).data_ver.get(); // Test 1: Write with correct dataversion should succeed let input_correct_dataver = &[AttrData::new( @@ -684,14 +583,13 @@ fn test_write_data_ver() { AttrPath::new(&ep0_attwrite), attr_data0, )]; - handle_write_reqs( - &mut im, + im.handle_write_reqs( peer, None, input_correct_dataver, &[AttrStatus::new(&ep0_attwrite, IMStatusCode::Success, 0)], ); - assert_eq!(AttrValue::Uint16(val0), read_cluster_id_write_attr(&im, 0)); + assert_eq!(val0, im.echo_cluster(0).att_write); // Test 2: Write with incorrect dataversion should fail // Now the data version would have incremented due to the previous write @@ -700,8 +598,7 @@ fn test_write_data_ver() { AttrPath::new(&ep0_attwrite), attr_data1, )]; - handle_write_reqs( - &mut im, + im.handle_write_reqs( peer, None, input_correct_dataver, @@ -711,26 +608,25 @@ fn test_write_data_ver() { 0, )], ); - assert_eq!(AttrValue::Uint16(val0), read_cluster_id_write_attr(&im, 0)); + assert_eq!(val0, im.echo_cluster(0).att_write); // Test 3: Wildcard write with incorrect dataversion should ignore that cluster // In this case, while the data version is correct for endpoint 0, the endpoint 1's // data version would not match - let new_data_ver = read_cluster_id_data_ver(&im, 0); + let new_data_ver = im.echo_cluster(0).data_ver.get(); let input_correct_dataver = &[AttrData::new( Some(new_data_ver), AttrPath::new(&wc_ep_attwrite), attr_data1, )]; - handle_write_reqs( - &mut im, + im.handle_write_reqs( peer, None, input_correct_dataver, &[AttrStatus::new(&ep0_attwrite, IMStatusCode::Success, 0)], ); - assert_eq!(AttrValue::Uint16(val1), read_cluster_id_write_attr(&im, 0)); + assert_eq!(val1, im.echo_cluster(0).att_write); assert_eq!(initial_data_ver + 1, new_data_ver); } diff --git a/matter/tests/data_model/attribute_lists.rs b/matter/tests/data_model/attribute_lists.rs index 62e79c4c..ace1f3db 100644 --- a/matter/tests/data_model/attribute_lists.rs +++ b/matter/tests/data_model/attribute_lists.rs @@ -16,36 +16,22 @@ */ use matter::{ - data_model::{core::DataModel, objects::EncodeValue}, + data_model::objects::EncodeValue, interaction_model::{ - core::{IMStatusCode, OpCode}, + core::IMStatusCode, + messages::ib::{AttrData, AttrPath, AttrStatus}, messages::GenericPath, - messages::{ - ib::{AttrData, AttrPath, AttrStatus}, - msg::{WriteReq, WriteResp}, - }, }, - tlv::{self, FromTLV, Nullable}, + mdns::DummyMdns, + tlv::Nullable, }; use crate::common::{ echo_cluster::{self, TestChecker}, - im_engine::im_engine, + im_engine::{matter, ImEngine}, }; // Helper for handling Write Attribute sequences -fn handle_write_reqs(input: &[AttrData], expected: &[AttrStatus]) -> DataModel { - let mut out_buf = [0u8; 400]; - let write_req = WriteReq::new(false, input); - - let (dm, _, out_buf) = im_engine(OpCode::WriteRequest, &write_req, &mut out_buf); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); - let resp = WriteResp::from_tlv(&root).unwrap(); - assert_eq!(resp.write_responses, expected); - dm -} - #[test] /// This tests all the attribute list operations /// add item, edit item, delete item, overwrite list, delete list @@ -67,14 +53,14 @@ fn attr_list_ops() { let att_data = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWriteList as u32), + Some(echo_cluster::AttributesDiscriminants::AttWriteList as u32), ); let mut att_path = AttrPath::new(&att_data); // Test 1: Add Operation - add val0 let input = &[AttrData::new(None, att_path, EncodeValue::Value(&val0))]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - let _ = handle_write_reqs(input, expected); + ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); { let tc = tc_handle.lock().unwrap(); @@ -84,7 +70,7 @@ fn attr_list_ops() { // Test 2: Another Add Operation - add val1 let input = &[AttrData::new(None, att_path, EncodeValue::Value(&val1))]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - let _ = handle_write_reqs(input, expected); + ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); { let tc = tc_handle.lock().unwrap(); @@ -95,7 +81,7 @@ fn attr_list_ops() { att_path.list_index = Some(Nullable::NotNull(1)); let input = &[AttrData::new(None, att_path, EncodeValue::Value(&val0))]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - let _ = handle_write_reqs(input, expected); + ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); { let tc = tc_handle.lock().unwrap(); @@ -106,7 +92,7 @@ fn attr_list_ops() { att_path.list_index = Some(Nullable::NotNull(0)); let input = &[AttrData::new(None, att_path, delete_item)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - let _ = handle_write_reqs(input, expected); + ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); { let tc = tc_handle.lock().unwrap(); @@ -122,7 +108,7 @@ fn attr_list_ops() { EncodeValue::Value(&overwrite_val), )]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - let _ = handle_write_reqs(input, expected); + ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); { let tc = tc_handle.lock().unwrap(); @@ -133,7 +119,7 @@ fn attr_list_ops() { att_path.list_index = None; let input = &[AttrData::new(None, att_path, delete_all)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - let _ = handle_write_reqs(input, expected); + ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); { let tc = tc_handle.lock().unwrap(); diff --git a/matter/tests/data_model/attributes.rs b/matter/tests/data_model/attributes.rs index 091b89fe..17e41124 100644 --- a/matter/tests/data_model/attributes.rs +++ b/matter/tests/data_model/attributes.rs @@ -18,53 +18,26 @@ use matter::{ data_model::{ cluster_on_off, - core::DataModel, - objects::{AttrValue, EncodeValue, GlobalElements}, + objects::{EncodeValue, GlobalElements}, }, interaction_model::{ - core::{IMStatusCode, OpCode}, + core::IMStatusCode, + messages::ib::{AttrData, AttrPath, AttrResp, AttrStatus}, messages::GenericPath, - messages::{ - ib::{AttrData, AttrPath, AttrResp, AttrStatus}, - msg::{ReadReq, ReportDataMsg, WriteReq, WriteResp}, - }, }, - tlv::{self, ElementType, FromTLV, TLVElement, TLVWriter, TagType}, + mdns::DummyMdns, + tlv::{ElementType, TLVElement, TLVWriter, TagType}, }; use crate::{ attr_data, attr_data_path, attr_status, - common::{attributes::*, echo_cluster, im_engine::im_engine}, + common::{ + attributes::*, + echo_cluster, + im_engine::{matter, ImEngine}, + }, }; -fn handle_read_reqs(input: &[AttrPath], expected: &[AttrResp]) { - let mut out_buf = [0u8; 400]; - let received = gen_read_reqs_output(input, &mut out_buf); - assert_attr_report(&received, expected) -} - -// Helper for handling Read Req sequences -fn gen_read_reqs_output<'a>(input: &[AttrPath], out_buf: &'a mut [u8]) -> ReportDataMsg<'a> { - let read_req = ReadReq::new(true).set_attr_requests(input); - let (_, _, out_buf) = im_engine(OpCode::ReadRequest, &read_req, out_buf); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); - ReportDataMsg::from_tlv(&root).unwrap() -} - -// Helper for handling Write Attribute sequences -fn handle_write_reqs(input: &[AttrData], expected: &[AttrStatus]) -> DataModel { - let mut out_buf = [0u8; 400]; - let write_req = WriteReq::new(false, input); - - let (dm, _, out_buf) = im_engine(OpCode::WriteRequest, &write_req, &mut out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); - let response = WriteResp::from_tlv(&root).unwrap(); - assert_eq!(response.write_responses, expected); - - dm -} - #[test] fn test_read_success() { // 3 Attr Read Requests @@ -76,17 +49,17 @@ fn test_read_success() { let ep0_att1 = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let ep1_att2 = GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att2 as u32), + Some(echo_cluster::AttributesDiscriminants::Att2 as u32), ); let ep1_attcustom = GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttCustom as u32), + Some(echo_cluster::AttributesDiscriminants::AttCustom as u32), ); let input = &[ AttrPath::new(&ep0_att1), @@ -101,7 +74,7 @@ fn test_read_success() { ElementType::U32(echo_cluster::ATTR_CUSTOM_VALUE) ), ]; - handle_read_reqs(input, expected); + ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); } #[test] @@ -118,17 +91,17 @@ fn test_read_unsupported_fields() { let invalid_endpoint = GenericPath::new( Some(2), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let invalid_cluster = GenericPath::new( Some(0), Some(0x1234), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let invalid_cluster_wc_endpoint = GenericPath::new( None, Some(0x1234), - Some(echo_cluster::Attributes::AttCustom as u32), + Some(echo_cluster::AttributesDiscriminants::AttCustom as u32), ); let invalid_attribute = GenericPath::new(Some(0), Some(echo_cluster::ID), Some(0x1234)); let invalid_attribute_wc_endpoint = @@ -148,7 +121,7 @@ fn test_read_unsupported_fields() { attr_status!(&invalid_cluster, IMStatusCode::UnsupportedCluster), attr_status!(&invalid_attribute, IMStatusCode::UnsupportedAttribute), ]; - handle_read_reqs(input, expected); + ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); } #[test] @@ -161,7 +134,7 @@ fn test_read_wc_endpoint_all_have_clusters() { let wc_ep_att1 = GenericPath::new( None, Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let input = &[AttrPath::new(&wc_ep_att1)]; @@ -169,17 +142,17 @@ fn test_read_wc_endpoint_all_have_clusters() { attr_data!( 0, echo_cluster::ID, - echo_cluster::Attributes::Att1, + echo_cluster::AttributesDiscriminants::Att1, ElementType::U16(0x1234) ), attr_data!( 1, echo_cluster::ID, - echo_cluster::Attributes::Att1, + echo_cluster::AttributesDiscriminants::Att1, ElementType::U16(0x1234) ), ]; - handle_read_reqs(input, expected); + ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); } #[test] @@ -192,7 +165,7 @@ fn test_read_wc_endpoint_only_1_has_cluster() { let wc_ep_onoff = GenericPath::new( None, Some(cluster_on_off::ID), - Some(cluster_on_off::Attributes::OnOff as u32), + Some(cluster_on_off::AttributesDiscriminants::OnOff as u32), ); let input = &[AttrPath::new(&wc_ep_onoff)]; @@ -200,11 +173,11 @@ fn test_read_wc_endpoint_only_1_has_cluster() { GenericPath::new( Some(1), Some(cluster_on_off::ID), - Some(cluster_on_off::Attributes::OnOff as u32) + Some(cluster_on_off::AttributesDiscriminants::OnOff as u32) ), ElementType::False )]; - handle_read_reqs(input, expected); + ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); } #[test] @@ -221,10 +194,10 @@ fn test_read_wc_endpoint_wc_attribute() { &[ GlobalElements::FeatureMap as u16, GlobalElements::AttributeList as u16, - echo_cluster::Attributes::Att1 as u16, - echo_cluster::Attributes::Att2 as u16, - echo_cluster::Attributes::AttWrite as u16, - echo_cluster::Attributes::AttCustom as u16, + echo_cluster::AttributesDiscriminants::Att1 as u16, + echo_cluster::AttributesDiscriminants::Att2 as u16, + echo_cluster::AttributesDiscriminants::AttWrite as u16, + echo_cluster::AttributesDiscriminants::AttCustom as u16, ], ); let attr_list_tlv = attr_list.to_tlv(); @@ -250,7 +223,7 @@ fn test_read_wc_endpoint_wc_attribute() { GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ), ElementType::U16(0x1234) ), @@ -258,7 +231,7 @@ fn test_read_wc_endpoint_wc_attribute() { GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att2 as u32), + Some(echo_cluster::AttributesDiscriminants::Att2 as u32), ), ElementType::U16(0x5678) ), @@ -266,7 +239,7 @@ fn test_read_wc_endpoint_wc_attribute() { GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttCustom as u32), + Some(echo_cluster::AttributesDiscriminants::AttCustom as u32), ), ElementType::U32(echo_cluster::ATTR_CUSTOM_VALUE) ), @@ -290,7 +263,7 @@ fn test_read_wc_endpoint_wc_attribute() { GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ), ElementType::U16(0x1234) ), @@ -298,7 +271,7 @@ fn test_read_wc_endpoint_wc_attribute() { GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att2 as u32), + Some(echo_cluster::AttributesDiscriminants::Att2 as u32), ), ElementType::U16(0x5678) ), @@ -306,12 +279,12 @@ fn test_read_wc_endpoint_wc_attribute() { GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttCustom as u32), + Some(echo_cluster::AttributesDiscriminants::AttCustom as u32), ), ElementType::U32(echo_cluster::ATTR_CUSTOM_VALUE) ), ]; - handle_read_reqs(input, expected); + ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); } #[test] @@ -332,12 +305,12 @@ fn test_write_success() { let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let ep1_att = GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let input = &[ @@ -357,24 +330,11 @@ fn test_write_success() { AttrStatus::new(&ep1_att, IMStatusCode::Success, 0), ]; - let dm = handle_write_reqs(input, expected); - let node = dm.node.read().unwrap(); - let echo = node.get_cluster(0, echo_cluster::ID).unwrap(); - assert_eq!( - AttrValue::Uint16(val0), - *echo - .base() - .read_attribute_raw(echo_cluster::Attributes::AttWrite as u16) - .unwrap() - ); - let echo = node.get_cluster(1, echo_cluster::ID).unwrap(); - assert_eq!( - AttrValue::Uint16(val1), - *echo - .base() - .read_attribute_raw(echo_cluster::Attributes::AttWrite as u16) - .unwrap() - ); + let mut mdns = DummyMdns; + let matter = matter(&mut mdns); + let im = ImEngine::new_with_write_reqs(&matter, input, expected); + assert_eq!(val0, im.echo_cluster(0).att_write); + assert_eq!(val1, im.echo_cluster(1).att_write); } #[test] @@ -390,7 +350,7 @@ fn test_write_wc_endpoint() { let ep_att = GenericPath::new( None, Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let input = &[AttrData::new( None, @@ -401,38 +361,23 @@ fn test_write_wc_endpoint() { let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let ep1_att = GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let expected = &[ AttrStatus::new(&ep0_att, IMStatusCode::Success, 0), AttrStatus::new(&ep1_att, IMStatusCode::Success, 0), ]; - let dm = handle_write_reqs(input, expected); - assert_eq!( - AttrValue::Uint16(val0), - dm.read_attribute_raw( - 0, - echo_cluster::ID, - echo_cluster::Attributes::AttWrite as u16 - ) - .unwrap() - ); - assert_eq!( - AttrValue::Uint16(val0), - dm.read_attribute_raw( - 0, - echo_cluster::ID, - echo_cluster::Attributes::AttWrite as u16 - ) - .unwrap() - ); + let mut mdns = DummyMdns; + let matter = matter(&mut mdns); + let im = ImEngine::new_with_write_reqs(&matter, input, expected); + assert_eq!(val0, im.echo_cluster(0).att_write); } #[test] @@ -455,25 +400,25 @@ fn test_write_unsupported_fields() { let invalid_endpoint = GenericPath::new( Some(4), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let invalid_cluster = GenericPath::new( Some(0), Some(0x1234), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let invalid_attribute = GenericPath::new(Some(0), Some(echo_cluster::ID), Some(0x1234)); let wc_endpoint_invalid_cluster = GenericPath::new( None, Some(0x1234), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let wc_endpoint_invalid_attribute = GenericPath::new(None, Some(echo_cluster::ID), Some(0x1234)); let wc_cluster = GenericPath::new( Some(0), None, - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let wc_attribute = GenericPath::new(Some(0), Some(echo_cluster::ID), None); @@ -521,14 +466,11 @@ fn test_write_unsupported_fields() { AttrStatus::new(&wc_cluster, IMStatusCode::UnsupportedCluster, 0), AttrStatus::new(&wc_attribute, IMStatusCode::UnsupportedAttribute, 0), ]; - let dm = handle_write_reqs(input, expected); + let mut mdns = DummyMdns; + let matter = matter(&mut mdns); + let im = ImEngine::new_with_write_reqs(&matter, input, expected); assert_eq!( - AttrValue::Uint16(echo_cluster::ATTR_WRITE_DEFAULT_VALUE), - dm.read_attribute_raw( - 0, - echo_cluster::ID, - echo_cluster::Attributes::AttWrite as u16 - ) - .unwrap() + echo_cluster::ATTR_WRITE_DEFAULT_VALUE, + im.echo_cluster(0).att_write ); } diff --git a/matter/tests/data_model/commands.rs b/matter/tests/data_model/commands.rs index 353b6627..50c1a8a3 100644 --- a/matter/tests/data_model/commands.rs +++ b/matter/tests/data_model/commands.rs @@ -17,39 +17,23 @@ use crate::{ cmd_data, - common::{commands::*, echo_cluster, im_engine::im_engine}, + common::{ + commands::*, + echo_cluster, + im_engine::{matter, ImEngine}, + }, echo_req, echo_resp, }; use matter::{ data_model::{cluster_on_off, objects::EncodeValue}, interaction_model::{ - core::{IMStatusCode, OpCode}, - messages::{ - ib::{CmdData, CmdPath, CmdStatus}, - msg, - msg::InvReq, - }, + core::IMStatusCode, + messages::ib::{CmdData, CmdPath, CmdStatus}, }, - tlv::{self, FromTLV, TLVArray}, + mdns::DummyMdns, }; -// Helper for handling Invoke Command sequences -fn handle_commands(input: &[CmdData], expected: &[ExpectedInvResp]) { - let mut out_buf = [0u8; 400]; - let req = InvReq { - suppress_response: Some(false), - timed_request: Some(false), - inv_requests: Some(TLVArray::Slice(input)), - }; - - let (_, _, out_buf) = im_engine(OpCode::InvokeRequest, &req, &mut out_buf); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); - let resp = msg::InvResp::from_tlv(&root).unwrap(); - assert_inv_response(&resp, expected) -} - #[test] fn test_invoke_cmds_success() { // 2 echo Requests @@ -59,7 +43,7 @@ fn test_invoke_cmds_success() { let input = &[echo_req!(0, 5), echo_req!(1, 10)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 30)]; - handle_commands(input, expected); + ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); } #[test] @@ -75,17 +59,17 @@ fn test_invoke_cmds_unsupported_fields() { let invalid_endpoint = CmdPath::new( Some(2), Some(echo_cluster::ID), - Some(echo_cluster::Commands::EchoReq as u16), + Some(echo_cluster::Commands::EchoReq as u32), ); let invalid_cluster = CmdPath::new( Some(0), Some(0x1234), - Some(echo_cluster::Commands::EchoReq as u16), + Some(echo_cluster::Commands::EchoReq as u32), ); let invalid_cluster_wc_endpoint = CmdPath::new( None, Some(0x1234), - Some(echo_cluster::Commands::EchoReq as u16), + Some(echo_cluster::Commands::EchoReq as u32), ); let invalid_command = CmdPath::new(Some(0), Some(echo_cluster::ID), Some(0x1234)); let invalid_command_wc_endpoint = CmdPath::new(None, Some(echo_cluster::ID), Some(0x1234)); @@ -114,7 +98,7 @@ fn test_invoke_cmds_unsupported_fields() { 0, )), ]; - handle_commands(input, expected); + ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); } #[test] @@ -125,11 +109,11 @@ fn test_invoke_cmd_wc_endpoint_all_have_clusters() { let path = CmdPath::new( None, Some(echo_cluster::ID), - Some(echo_cluster::Commands::EchoReq as u16), + Some(echo_cluster::Commands::EchoReq as u32), ); let input = &[cmd_data!(path, 5)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 15)]; - handle_commands(input, expected); + ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); } #[test] @@ -141,12 +125,12 @@ fn test_invoke_cmd_wc_endpoint_only_1_has_cluster() { let target = CmdPath::new( None, Some(cluster_on_off::ID), - Some(cluster_on_off::Commands::On as u16), + Some(cluster_on_off::CommandsDiscriminants::On as u32), ); let expected_path = CmdPath::new( Some(1), Some(cluster_on_off::ID), - Some(cluster_on_off::Commands::On as u16), + Some(cluster_on_off::CommandsDiscriminants::On as u32), ); let input = &[cmd_data!(target, 1)]; let expected = &[ExpectedInvResp::Status(CmdStatus::new( @@ -154,5 +138,5 @@ fn test_invoke_cmd_wc_endpoint_only_1_has_cluster() { IMStatusCode::Success, 0, ))]; - handle_commands(input, expected); + ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); } diff --git a/matter/tests/data_model/timed_requests.rs b/matter/tests/data_model/timed_requests.rs index 1c4a9414..cf5ddbd7 100644 --- a/matter/tests/data_model/timed_requests.rs +++ b/matter/tests/data_model/timed_requests.rs @@ -15,112 +15,27 @@ * limitations under the License. */ -use core::time; -use std::thread; - use matter::{ - data_model::{ - core::DataModel, - objects::{AttrValue, EncodeValue}, - }, + data_model::objects::EncodeValue, interaction_model::{ - core::{IMStatusCode, OpCode}, - messages::{ib::CmdData, ib::CmdPath, msg::InvReq, GenericPath}, - messages::{ - ib::{AttrData, AttrPath, AttrStatus}, - msg::{self, StatusResp, TimedReq, WriteReq, WriteResp}, - }, + core::IMStatusCode, + messages::ib::{AttrData, AttrPath, AttrStatus}, + messages::{ib::CmdData, ib::CmdPath, GenericPath}, }, - tlv::{self, FromTLV, TLVArray, TLVWriter, ToTLV}, - transport::exchange::{self, Exchange}, + mdns::DummyMdns, + tlv::TLVWriter, }; use crate::{ common::{ commands::*, echo_cluster, - im_engine::{ImEngine, ImInput}, + handlers::{TimedInvResponse, WriteResponse}, + im_engine::{matter, ImEngine}, }, echo_req, echo_resp, }; -fn handle_timed_reqs<'a>( - opcode: OpCode, - request: &dyn ToTLV, - timeout: u16, - delay: u16, - output: &'a mut [u8], -) -> (u8, DataModel, &'a [u8]) { - let mut im_engine = ImEngine::new(); - // Use the same exchange for all parts of the transaction - im_engine.exch = Some(Exchange::new(1, 0, exchange::Role::Responder)); - - if timeout != 0 { - // Send Timed Req - let mut tmp_buf = [0u8; 400]; - let timed_req = TimedReq { timeout }; - let im_input = ImInput::new(OpCode::TimedRequest, &timed_req); - let (_, out_buf) = im_engine.process(&im_input, &mut tmp_buf); - tlv::print_tlv_list(out_buf); - } else { - println!("Skipping timed request"); - } - - // Process any delays - let delay = time::Duration::from_millis(delay.into()); - thread::sleep(delay); - - // Send Write Req - let input = ImInput::new(opcode, request); - let (resp_opcode, output) = im_engine.process(&input, output); - (resp_opcode, im_engine.dm, output) -} -enum WriteResponse<'a> { - TransactionError, - TransactionSuccess(&'a [AttrStatus]), -} - -// Helper for handling Write Attribute sequences -fn handle_timed_write_reqs( - input: &[AttrData], - expected: &WriteResponse, - timeout: u16, - delay: u16, -) -> DataModel { - let mut out_buf = [0u8; 400]; - let write_req = WriteReq::new(false, input); - - let (resp_opcode, dm, out_buf) = handle_timed_reqs( - OpCode::WriteRequest, - &write_req, - timeout, - delay, - &mut out_buf, - ); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); - - match expected { - WriteResponse::TransactionSuccess(t) => { - assert_eq!( - num::FromPrimitive::from_u8(resp_opcode), - Some(OpCode::WriteResponse) - ); - let resp = WriteResp::from_tlv(&root).unwrap(); - assert_eq!(resp.write_responses, t); - } - WriteResponse::TransactionError => { - assert_eq!( - num::FromPrimitive::from_u8(resp_opcode), - Some(OpCode::StatusResponse) - ); - let status_resp = StatusResp::from_tlv(&root).unwrap(); - assert_eq!(status_resp.status, IMStatusCode::Timeout); - } - } - dm -} - #[test] fn test_timed_write_fail_and_success() { // - 1 Timed Attr Write Transaction should fail due to timeout @@ -134,7 +49,7 @@ fn test_timed_write_fail_and_success() { let ep_att = GenericPath::new( None, Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let input = &[AttrData::new( None, @@ -145,13 +60,13 @@ fn test_timed_write_fail_and_success() { let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let ep1_att = GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let expected = &[ AttrStatus::new(&ep0_att, IMStatusCode::Success, 0), @@ -159,73 +74,25 @@ fn test_timed_write_fail_and_success() { ]; // Test with incorrect handling - handle_timed_write_reqs(input, &WriteResponse::TransactionError, 400, 500); + ImEngine::new_with_timed_write_reqs( + &matter(&mut DummyMdns), + input, + &WriteResponse::TransactionError, + 400, + 500, + ); // Test with correct handling - let dm = handle_timed_write_reqs(input, &WriteResponse::TransactionSuccess(expected), 400, 0); - assert_eq!( - AttrValue::Uint16(val0), - dm.read_attribute_raw( - 0, - echo_cluster::ID, - echo_cluster::Attributes::AttWrite as u16 - ) - .unwrap() - ); - assert_eq!( - AttrValue::Uint16(val0), - dm.read_attribute_raw( - 0, - echo_cluster::ID, - echo_cluster::Attributes::AttWrite as u16 - ) - .unwrap() + let mut mdns = DummyMdns; + let matter = matter(&mut mdns); + let im = ImEngine::new_with_timed_write_reqs( + &matter, + input, + &WriteResponse::TransactionSuccess(expected), + 400, + 0, ); -} - -enum TimedInvResponse<'a> { - TransactionError(IMStatusCode), - TransactionSuccess(&'a [ExpectedInvResp]), -} -// Helper for handling Invoke Command sequences -fn handle_timed_commands( - input: &[CmdData], - expected: &TimedInvResponse, - timeout: u16, - delay: u16, - set_timed_request: bool, -) -> DataModel { - let mut out_buf = [0u8; 400]; - let req = InvReq { - suppress_response: Some(false), - timed_request: Some(set_timed_request), - inv_requests: Some(TLVArray::Slice(input)), - }; - - let (resp_opcode, dm, out_buf) = - handle_timed_reqs(OpCode::InvokeRequest, &req, timeout, delay, &mut out_buf); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); - - match expected { - TimedInvResponse::TransactionSuccess(t) => { - assert_eq!( - num::FromPrimitive::from_u8(resp_opcode), - Some(OpCode::InvokeResponse) - ); - let resp = msg::InvResp::from_tlv(&root).unwrap(); - assert_inv_response(&resp, t) - } - TimedInvResponse::TransactionError(e) => { - assert_eq!( - num::FromPrimitive::from_u8(resp_opcode), - Some(OpCode::StatusResponse) - ); - let status_resp = StatusResp::from_tlv(&root).unwrap(); - assert_eq!(status_resp.status, *e); - } - } - dm + assert_eq!(val0, im.echo_cluster(0).att_write); } #[test] @@ -235,7 +102,8 @@ fn test_timed_cmd_success() { let input = &[echo_req!(0, 5), echo_req!(1, 10)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 30)]; - handle_timed_commands( + ImEngine::new_with_timed_commands( + &matter(&mut DummyMdns), input, &TimedInvResponse::TransactionSuccess(expected), 400, @@ -250,7 +118,8 @@ fn test_timed_cmd_timeout() { let _ = env_logger::try_init(); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; - handle_timed_commands( + ImEngine::new_with_timed_commands( + &matter(&mut DummyMdns), input, &TimedInvResponse::TransactionError(IMStatusCode::Timeout), 400, @@ -265,7 +134,8 @@ fn test_timed_cmd_timedout_mismatch() { let _ = env_logger::try_init(); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; - handle_timed_commands( + ImEngine::new_with_timed_commands( + &matter(&mut DummyMdns), input, &TimedInvResponse::TransactionError(IMStatusCode::TimedRequestMisMatch), 400, @@ -274,7 +144,8 @@ fn test_timed_cmd_timedout_mismatch() { ); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; - handle_timed_commands( + ImEngine::new_with_timed_commands( + &matter(&mut DummyMdns), input, &TimedInvResponse::TransactionError(IMStatusCode::TimedRequestMisMatch), 0, diff --git a/matter/tests/data_model_tests.rs b/matter/tests/data_model_tests.rs index 392909fa..803c4c5e 100644 --- a/matter/tests/data_model_tests.rs +++ b/matter/tests/data_model_tests.rs @@ -22,6 +22,6 @@ mod data_model { mod attribute_lists; mod attributes; mod commands; - mod long_reads; + // TODO mod long_reads; mod timed_requests; } diff --git a/matter/tests/interaction_model.rs b/matter/tests/interaction_model.rs index 5e88f8af..07d114e9 100644 --- a/matter/tests/interaction_model.rs +++ b/matter/tests/interaction_model.rs @@ -15,116 +15,70 @@ * limitations under the License. */ -use boxslab::Slab; +use matter::data_model::core::DataHandler; use matter::error::Error; +use matter::interaction_model::core::Interaction; +use matter::interaction_model::core::InteractionModel; use matter::interaction_model::core::OpCode; -use matter::interaction_model::messages::msg::InvReq; -use matter::interaction_model::messages::msg::WriteReq; -use matter::interaction_model::InteractionConsumer; -use matter::interaction_model::InteractionModel; -use matter::interaction_model::Transaction; -use matter::tlv::TLVWriter; +use matter::interaction_model::core::Transaction; use matter::transport::exchange::Exchange; use matter::transport::exchange::ExchangeCtx; use matter::transport::network::Address; use matter::transport::packet::Packet; -use matter::transport::packet::PacketPool; -use matter::transport::proto_demux::HandleProto; -use matter::transport::proto_demux::ProtoCtx; -use matter::transport::proto_demux::ResponseRequired; +use matter::transport::proto_ctx::ProtoCtx; use matter::transport::session::SessionMgr; +use matter::utils::epoch::dummy_epoch; +use matter::utils::rand::dummy_rand; use std::net::Ipv4Addr; use std::net::SocketAddr; -use std::sync::{Arc, Mutex}; struct Node { pub endpoint: u16, pub cluster: u32, - pub command: u32, + pub command: u16, pub variable: u8, } struct DataModel { - node: Arc>, + node: Node, } impl DataModel { pub fn new(node: Node) -> Self { - DataModel { - node: Arc::new(Mutex::new(node)), - } + DataModel { node } } } -impl Clone for DataModel { - fn clone(&self) -> Self { - Self { - node: self.node.clone(), - } - } -} - -impl InteractionConsumer for DataModel { - fn consume_invoke_cmd( - &self, - inv_req_msg: &InvReq, - _trans: &mut Transaction, - _tlvwriter: &mut TLVWriter, - ) -> Result<(), Error> { - if let Some(inv_requests) = &inv_req_msg.inv_requests { - for i in inv_requests.iter() { - let data = if let Some(data) = i.data.unwrap_tlv() { - data - } else { - continue; - }; - let cmd_path_ib = i.path; - let mut common_data = self.node.lock().unwrap(); - common_data.endpoint = cmd_path_ib.path.endpoint.unwrap_or(1); - common_data.cluster = cmd_path_ib.path.cluster.unwrap_or(0); - common_data.command = cmd_path_ib.path.leaf.unwrap_or(0); - data.confirm_struct().unwrap(); - common_data.variable = data.find_tag(0).unwrap().u8().unwrap(); +impl DataHandler for DataModel { + fn handle( + &mut self, + interaction: &Interaction, + _tx: &mut Packet, + _transaction: &mut Transaction, + ) -> Result { + match interaction { + Interaction::Invoke(req) => { + if let Some(inv_requests) = &req.inv_requests { + for i in inv_requests.iter() { + let data = if let Some(data) = i.data.unwrap_tlv() { + data + } else { + continue; + }; + let cmd_path_ib = i.path; + let mut common_data = &mut self.node; + common_data.endpoint = cmd_path_ib.path.endpoint.unwrap_or(1); + common_data.cluster = cmd_path_ib.path.cluster.unwrap_or(0); + common_data.command = cmd_path_ib.path.leaf.unwrap_or(0) as u16; + data.confirm_struct().unwrap(); + common_data.variable = data.find_tag(0).unwrap().u8().unwrap(); + } + } } + _ => (), } - Ok(()) - } - - fn consume_read_attr( - &self, - _req: &[u8], - _trans: &mut Transaction, - _tlvwriter: &mut TLVWriter, - ) -> Result<(), Error> { - Ok(()) - } - - fn consume_write_attr( - &self, - _req: &WriteReq, - _trans: &mut Transaction, - _tlvwriter: &mut TLVWriter, - ) -> Result<(), Error> { - Ok(()) - } - - fn consume_status_report( - &self, - _req: &matter::interaction_model::messages::msg::StatusResp, - _trans: &mut Transaction, - _tw: &mut TLVWriter, - ) -> Result<(OpCode, ResponseRequired), Error> { - Ok((OpCode::StatusResponse, ResponseRequired::No)) - } - - fn consume_subscribe( - &self, - _req: &[u8], - _trans: &mut Transaction, - _tw: &mut TLVWriter, - ) -> Result<(OpCode, matter::transport::proto_demux::ResponseRequired), Error> { - Ok((OpCode::StatusResponse, ResponseRequired::No)) + Ok(false) } } @@ -135,9 +89,9 @@ fn handle_data(action: OpCode, data_in: &[u8], data_out: &mut [u8]) -> (DataMode command: 0, variable: 0, }); - let mut interaction_model = InteractionModel::new(Box::new(data_model.clone())); + let mut interaction_model = InteractionModel(data_model); let mut exch: Exchange = Default::default(); - let mut sess_mgr: SessionMgr = Default::default(); + let mut sess_mgr = SessionMgr::new(dummy_epoch, dummy_rand); let sess_idx = sess_mgr .get_or_add( 0, @@ -153,24 +107,27 @@ fn handle_data(action: OpCode, data_in: &[u8], data_out: &mut [u8]) -> (DataMode let exch_ctx = ExchangeCtx { exch: &mut exch, sess, + epoch: dummy_epoch, }; - let mut rx = Slab::::try_new(Packet::new_rx().unwrap()).unwrap(); - let tx = Slab::::try_new(Packet::new_tx().unwrap()).unwrap(); + let mut rx_buf = [0; 1500]; + let mut tx_buf = [0; 1500]; + let mut rx = Packet::new_rx(&mut rx_buf); + let mut tx = Packet::new_tx(&mut tx_buf); // Create fake rx packet rx.set_proto_id(0x01); rx.set_proto_opcode(action as u8); rx.peer = Address::default(); let in_data_len = data_in.len(); - let rx_buf = rx.as_borrow_slice(); + let rx_buf = rx.as_mut_slice(); rx_buf[..in_data_len].copy_from_slice(data_in); - let mut ctx = ProtoCtx::new(exch_ctx, rx, tx); + let mut ctx = ProtoCtx::new(exch_ctx, &rx, &mut tx); - interaction_model.handle_proto_id(&mut ctx).unwrap(); + interaction_model.handle(&mut ctx).unwrap(); - let out_len = ctx.tx.as_borrow_slice().len(); - data_out[..out_len].copy_from_slice(ctx.tx.as_borrow_slice()); - (data_model, out_len) + let out_len = ctx.tx.as_mut_slice().len(); + data_out[..out_len].copy_from_slice(ctx.tx.as_mut_slice()); + (interaction_model.0, out_len) } #[test] @@ -186,7 +143,7 @@ fn test_valid_invoke_cmd() -> Result<(), Error> { let mut out_buf: [u8; 20] = [0; 20]; let (data_model, _) = handle_data(OpCode::InvokeRequest, &b, &mut out_buf); - let data = data_model.node.lock().unwrap(); + let data = &data_model.node; assert_eq!(data.endpoint, 0); assert_eq!(data.cluster, 49); assert_eq!(data.command, 12); From 817d55aecc400a321eaed0beb0de0b6cb224eefc Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Fri, 21 Apr 2023 12:00:17 +0000 Subject: [PATCH 02/72] Start reintroducing long reads and subscriptions from mainline --- matter/src/data_model/core.rs | 11 +++++ matter/src/data_model/objects/dataver.rs | 2 +- matter/src/data_model/objects/node.rs | 47 +++++++++++++++--- matter/src/interaction_model/core.rs | 63 +++++++++++++++++++++--- 4 files changed, 107 insertions(+), 16 deletions(-) diff --git a/matter/src/data_model/core.rs b/matter/src/data_model/core.rs index 005871d5..0ca83792 100644 --- a/matter/src/data_model/core.rs +++ b/matter/src/data_model/core.rs @@ -69,6 +69,17 @@ impl<'a, T> DataModel<'a, T> { CmdDataEncoder::handle(item, &mut self.handler, transaction, &mut tw)?; } } + Interaction::Subscribe(req) => { + for item in self.node.subscribing_read(req, &accessor) { + AttrDataEncoder::handle_read(item, &self.handler, &mut tw)?; + } + } + Interaction::Status(_resp) => { + todo!() + // for item in self.node.subscribing_read(req, &accessor) { + // AttrDataEncoder::handle_read(item, &self.handler, &mut tw)?; + // } + } Interaction::Timed(_) => (), } diff --git a/matter/src/data_model/objects/dataver.rs b/matter/src/data_model/objects/dataver.rs index fc062be0..f05a3838 100644 --- a/matter/src/data_model/objects/dataver.rs +++ b/matter/src/data_model/objects/dataver.rs @@ -15,7 +15,7 @@ * limitations under the License. */ - use crate::utils::rand::Rand; +use crate::utils::rand::Rand; pub struct Dataver { ver: u32, diff --git a/matter/src/data_model/objects/node.rs b/matter/src/data_model/objects/node.rs index 2eb11754..4ec1765c 100644 --- a/matter/src/data_model/objects/node.rs +++ b/matter/src/data_model/objects/node.rs @@ -21,13 +21,13 @@ use crate::{ interaction_model::{ core::IMStatusCode, messages::{ - ib::{AttrStatus, CmdStatus}, - msg::{InvReq, ReadReq, WriteReq}, + ib::{AttrPath, AttrStatus, CmdStatus, DataVersionFilter}, + msg::{InvReq, ReadReq, SubscribeReq, WriteReq}, GenericPath, }, }, // TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer - tlv::TLVElement, + tlv::{TLVArray, TLVElement}, }; use core::{ fmt, @@ -72,15 +72,48 @@ impl<'a> Node<'a> { where 's: 'm, { - if let Some(attr_requests) = req.attr_requests.as_ref() { + self.read_attr_requests( + req.attr_requests.as_ref(), + req.dataver_filters.as_ref(), + req.fabric_filtered, + accessor, + ) + } + + pub fn subscribing_read<'s, 'm>( + &'s self, + req: &'m SubscribeReq, + accessor: &'m Accessor<'m>, + ) -> impl Iterator> + 'm + where + 's: 'm, + { + self.read_attr_requests( + req.attr_requests.as_ref(), + req.dataver_filters.as_ref(), + req.fabric_filtered, + accessor, + ) + } + + fn read_attr_requests<'s, 'm>( + &'s self, + attr_requests: Option<&'m TLVArray>, + dataver_filters: Option<&'m TLVArray>, + fabric_filtered: bool, + accessor: &'m Accessor<'m>, + ) -> impl Iterator> + 'm + where + 's: 'm, + { + if let Some(attr_requests) = attr_requests.as_ref() { WildcardIter::Wildcard(attr_requests.iter().flat_map( move |path| match self.expand_attr(accessor, path.to_gp(), false) { Ok(iter) => { let wildcard = matches!(iter, WildcardIter::Wildcard(_)); WildcardIter::Wildcard(iter.map(move |(ep, cl, attr)| { - let dataver_filter = req - .dataver_filters + let dataver_filter = dataver_filters .as_ref() .iter() .flat_map(|array| array.iter()) @@ -96,7 +129,7 @@ impl<'a> Node<'a> { attr_id: attr, list_index: path.list_index, fab_idx: accessor.fab_idx, - fab_filter: req.fabric_filtered, + fab_filter: fabric_filtered, dataver: dataver_filter, wildcard, }) diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index 8d8b4fb4..935740e6 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -28,7 +28,7 @@ use log::{error, info}; use num; use num_derive::FromPrimitive; -use super::messages::msg::{self, InvReq, ReadReq, StatusResp, TimedReq, WriteReq}; +use super::messages::msg::{self, InvReq, ReadReq, StatusResp, SubscribeReq, TimedReq, WriteReq}; #[macro_export] macro_rules! cmd_enter { @@ -104,7 +104,7 @@ pub enum OpCode { StatusResponse = 1, ReadRequest = 2, SubscribeRequest = 3, - SubscriptResponse = 4, + SubscribeResponse = 4, ReportData = 5, WriteRequest = 6, WriteResponse = 7, @@ -186,6 +186,8 @@ pub enum Interaction<'a> { Read(ReadReq<'a>), Write(WriteReq<'a>), Invoke(InvReq<'a>), + Subscribe(SubscribeReq<'a>), + Status(StatusResp), Timed(TimedReq), } @@ -209,12 +211,15 @@ impl<'a> Interaction<'a> { OpCode::InvokeRequest => Ok(Self::Invoke(InvReq::from_tlv(&get_root_node_struct( rx_data, )?)?)), + OpCode::SubscribeRequest => Ok(Self::Subscribe(SubscribeReq::from_tlv( + &get_root_node_struct(rx_data)?, + )?)), + OpCode::StatusResponse => Ok(Self::Status(StatusResp::from_tlv( + &get_root_node_struct(rx_data)?, + )?)), OpCode::TimedRequest => Ok(Self::Timed(TimedReq::from_tlv(&get_root_node_struct( rx_data, )?)?)), - // TODO - // OpCode::SubscribeRequest => self.handle_subscribe_req(&mut trans, buf, &mut ctx.tx)?, - // OpCode::StatusResponse => self.handle_status_resp(&mut trans, buf, &mut ctx.tx)?, _ => { error!("Opcode Not Handled: {:?}", opcode); Err(Error::InvalidOpcode) @@ -242,7 +247,7 @@ impl<'a> Interaction<'a> { false } - Interaction::Write(_) => { + Self::Write(_) => { if transaction.has_timed_out() { Self::create_status_response(tx, IMStatusCode::Timeout)?; @@ -262,7 +267,7 @@ impl<'a> Interaction<'a> { false } } - Interaction::Invoke(request) => { + Self::Invoke(request) => { if transaction.has_timed_out() { Self::create_status_response(tx, IMStatusCode::Timeout)?; @@ -303,7 +308,31 @@ impl<'a> Interaction<'a> { } } } - Interaction::Timed(request) => { + Self::Subscribe(request) => { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::ReportData as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + tw.start_struct(TagType::Anonymous)?; + + if request.attr_requests.is_some() { + tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; + } + + true + } + Self::Status(_) => { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::SubscribeResponse as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + tw.start_struct(TagType::Anonymous)?; + + true + } + Self::Timed(request) => { tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); tx.set_proto_opcode(OpCode::StatusResponse as u8); @@ -377,6 +406,24 @@ impl<'a> Interaction<'a> { true } + Self::Subscribe(request) => { + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + if request.attr_requests.is_some() { + tw.end_container()?; + } + + tw.end_container()?; + + true + } + Self::Status(_) => { + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + tw.end_container()?; + + true + } Self::Timed(_) => false, }; From fcc87bfaf435062c9a562e76b809e4d024767e2c Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sat, 22 Apr 2023 14:39:17 +0000 Subject: [PATCH 03/72] Long reads and subscriptions reintroduced --- matter/Cargo.toml | 2 +- matter/src/acl.rs | 15 +- matter/src/cert/mod.rs | 4 +- .../data_model/cluster_basic_information.rs | 14 +- matter/src/data_model/core.rs | 197 ++++-- matter/src/data_model/objects/cluster.rs | 78 +-- matter/src/data_model/objects/encoder.rs | 11 +- matter/src/data_model/objects/endpoint.rs | 36 +- matter/src/data_model/objects/handler.rs | 9 +- matter/src/data_model/objects/node.rs | 453 ++++++++----- matter/src/data_model/root_endpoint.rs | 44 +- .../data_model/sdm/general_commissioning.rs | 6 +- matter/src/data_model/sdm/noc.rs | 4 +- matter/src/data_model/sdm/nw_commissioning.rs | 12 +- .../data_model/system_model/access_control.rs | 17 +- matter/src/fabric.rs | 6 +- matter/src/interaction_model/core.rs | 625 ++++++++++++------ matter/src/secure_channel/case.rs | 6 +- matter/src/transport/exchange.rs | 45 +- matter/src/transport/packet.rs | 1 + matter/src/transport/proto_hdr.rs | 2 +- matter/src/transport/session.rs | 4 +- matter/src/utils/writebuf.rs | 78 ++- matter/tests/common/im_engine.rs | 30 +- matter/tests/data_model/long_reads.rs | 185 ++++-- matter/tests/data_model_tests.rs | 2 +- matter/tests/interaction_model.rs | 41 +- matter_macro_derive/src/lib.rs | 50 +- 28 files changed, 1305 insertions(+), 672 deletions(-) diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 0987e10d..cf9497b4 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -15,7 +15,7 @@ name = "matter" path = "src/lib.rs" [features] -default = ["std", "crypto_mbedtls"] +default = ["std", "crypto_mbedtls", "nightly"] std = [] nightly = [] crypto_openssl = ["openssl", "foreign-types", "hmac", "sha2"] diff --git a/matter/src/acl.rs b/matter/src/acl.rs index 8f965e13..2bfd0d61 100644 --- a/matter/src/acl.rs +++ b/matter/src/acl.rs @@ -179,14 +179,14 @@ impl<'a> Accessor<'a> { let _ = subject.add_catid(i); } } - Accessor::new(c.fab_idx, subject, AuthMode::Case, &acl_mgr) + Accessor::new(c.fab_idx, subject, AuthMode::Case, acl_mgr) } SessionMode::Pase => { - Accessor::new(0, AccessorSubjects::new(1), AuthMode::Pase, &acl_mgr) + Accessor::new(0, AccessorSubjects::new(1), AuthMode::Pase, acl_mgr) } SessionMode::PlainText => { - Accessor::new(0, AccessorSubjects::new(1), AuthMode::Invalid, &acl_mgr) + Accessor::new(0, AccessorSubjects::new(1), AuthMode::Invalid, acl_mgr) } } } @@ -514,7 +514,7 @@ impl AclMgr { let mut wb = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut wb); self.entries.to_tlv(&mut tw, TagType::Anonymous)?; - psm.set_kv_slice(ACL_KV_ENTRY, wb.into_slice())?; + psm.set_kv_slice(ACL_KV_ENTRY, wb.as_slice())?; self.changed = false; } @@ -546,7 +546,7 @@ impl AclMgr { let mut wb = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut wb); self.entries.to_tlv(&mut tw, TagType::Anonymous)?; - psm.set_kv_slice(ACL_KV_ENTRY, wb.into_slice()).await?; + psm.set_kv_slice(ACL_KV_ENTRY, wb.as_slice()).await?; self.changed = false; } @@ -561,10 +561,7 @@ impl AclMgr { { let mut buf = [0u8; ACL_KV_MAX_SIZE]; let acl_tlvs = psm.get_kv_slice(ACL_KV_ENTRY, &mut buf).await?; - let root = TLVList::new(&acl_tlvs) - .iter() - .next() - .ok_or(Error::Invalid)?; + let root = TLVList::new(acl_tlvs).iter().next().ok_or(Error::Invalid)?; self.entries = AclEntries::from_tlv(&root)?; self.changed = false; diff --git a/matter/src/cert/mod.rs b/matter/src/cert/mod.rs index 664125b5..b9283299 100644 --- a/matter/src/cert/mod.rs +++ b/matter/src/cert/mod.rs @@ -597,7 +597,7 @@ impl Cert { let mut wb = WriteBuf::new(buf); let mut tw = TLVWriter::new(&mut wb); self.to_tlv(&mut tw, TagType::Anonymous)?; - Ok(wb.into_slice().len()) + Ok(wb.as_slice().len()) } pub fn as_asn1(&self, buf: &mut [u8]) -> Result { @@ -823,7 +823,7 @@ mod tests { let mut wb = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut wb); cert.to_tlv(&mut tw, TagType::Anonymous).unwrap(); - assert_eq!(*input, wb.into_slice()); + assert_eq!(*input, wb.as_slice()); } } diff --git a/matter/src/data_model/cluster_basic_information.rs b/matter/src/data_model/cluster_basic_information.rs index 71c07229..dcc70282 100644 --- a/matter/src/data_model/cluster_basic_information.rs +++ b/matter/src/data_model/cluster_basic_information.rs @@ -19,11 +19,11 @@ use core::convert::TryInto; use super::objects::*; use crate::{attribute_enum, error::Error, utils::rand::Rand}; -use strum::{EnumDiscriminants, FromRepr}; +use strum::FromRepr; pub const ID: u32 = 0x0028; -#[derive(Clone, Copy, Debug, FromRepr, EnumDiscriminants)] +#[derive(Clone, Copy, Debug, FromRepr)] #[repr(u16)] pub enum Attributes { DMRevision(AttrType) = 0, @@ -37,6 +37,16 @@ pub enum Attributes { attribute_enum!(Attributes); +pub enum AttributesDiscriminants { + DMRevision = 0, + VendorId = 2, + ProductId = 4, + HwVer = 7, + SwVer = 9, + SwVerString = 0xa, + SerialNo = 0x0f, +} + #[derive(Default)] pub struct BasicInfoConfig<'a> { pub vid: u16, diff --git a/matter/src/data_model/core.rs b/matter/src/data_model/core.rs index 0ca83792..b53f955d 100644 --- a/matter/src/data_model/core.rs +++ b/matter/src/data_model/core.rs @@ -15,17 +15,26 @@ * limitations under the License. */ -use core::cell::RefCell; +use core::{ + cell::RefCell, + sync::atomic::{AtomicU32, Ordering}, +}; use super::objects::*; use crate::{ acl::{Accessor, AclMgr}, error::*, - interaction_model::core::{Interaction, Transaction}, - tlv::TLVWriter, + interaction_model::{ + core::{Interaction, Transaction}, + messages::msg::SubscribeResp, + }, + tlv::{TLVWriter, TagType, ToTLV}, transport::packet::Packet, }; +// TODO: For now... +static SUBS_ID: AtomicU32 = AtomicU32::new(1); + pub struct DataModel<'a, T> { pub acl_mgr: &'a RefCell, pub node: &'a Node<'a>, @@ -43,7 +52,7 @@ impl<'a, T> DataModel<'a, T> { pub fn handle( &mut self, - interaction: &Interaction, + interaction: Interaction, tx: &mut Packet, transaction: &mut Transaction, ) -> Result @@ -55,44 +64,89 @@ impl<'a, T> DataModel<'a, T> { match interaction { Interaction::Read(req) => { - for item in self.node.read(req, &accessor) { - AttrDataEncoder::handle_read(item, &self.handler, &mut tw)?; + let mut resume_path = None; + + for item in self.node.read(&req, &accessor) { + if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? + { + resume_path = Some(path); + break; + } } + + req.complete(tx, transaction, resume_path) } Interaction::Write(req) => { - for item in self.node.write(req, &accessor) { + for item in self.node.write(&req, &accessor) { AttrDataEncoder::handle_write(item, &mut self.handler, &mut tw)?; } + + req.complete(tx, transaction) } Interaction::Invoke(req) => { - for item in self.node.invoke(req, &accessor) { + for item in self.node.invoke(&req, &accessor) { CmdDataEncoder::handle(item, &mut self.handler, transaction, &mut tw)?; } + + req.complete(tx, transaction) } Interaction::Subscribe(req) => { - for item in self.node.subscribing_read(req, &accessor) { - AttrDataEncoder::handle_read(item, &self.handler, &mut tw)?; + let mut resume_path = None; + + for item in self.node.subscribing_read(&req, &accessor) { + if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? + { + resume_path = Some(path); + break; + } + } + + req.complete(tx, transaction, resume_path) + } + Interaction::Timed(_) => Ok(false), + Interaction::ResumeRead(req) => { + let mut resume_path = None; + + for item in self.node.resume_read(&req, &accessor) { + if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? + { + resume_path = Some(path); + break; + } } + + req.complete(tx, transaction, resume_path) } - Interaction::Status(_resp) => { - todo!() - // for item in self.node.subscribing_read(req, &accessor) { - // AttrDataEncoder::handle_read(item, &self.handler, &mut tw)?; - // } + Interaction::ResumeSubscribe(req) => { + let mut resume_path = None; + + if req.resume_path.is_some() { + for item in self.node.resume_subscribing_read(&req, &accessor) { + if let Some(path) = + AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? + { + resume_path = Some(path); + break; + } + } + } else { + // TODO + let resp = SubscribeResp::new(SUBS_ID.fetch_add(1, Ordering::SeqCst), 40); + resp.to_tlv(&mut tw, TagType::Anonymous)?; + } + + req.complete(tx, transaction, resume_path) } - Interaction::Timed(_) => (), } - - interaction.complete_tx(tx, transaction) } #[cfg(feature = "nightly")] pub async fn handle_async<'p>( &mut self, - interaction: &Interaction<'_>, + interaction: Interaction<'_>, tx: &'p mut Packet<'_>, transaction: &mut Transaction<'_, '_>, - ) -> Result, Error> + ) -> Result where T: super::objects::asynch::AsyncHandler, { @@ -101,32 +155,91 @@ impl<'a, T> DataModel<'a, T> { match interaction { Interaction::Read(req) => { - for item in self.node.read(req, &accessor) { - AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await?; + let mut resume_path = None; + + for item in self.node.read(&req, &accessor) { + if let Some(path) = + AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? + { + resume_path = Some(path); + break; + } } + + req.complete(tx, transaction, resume_path) } Interaction::Write(req) => { - for item in self.node.write(req, &accessor) { + for item in self.node.write(&req, &accessor) { AttrDataEncoder::handle_write_async(item, &mut self.handler, &mut tw).await?; } + + req.complete(tx, transaction) } Interaction::Invoke(req) => { - for item in self.node.invoke(req, &accessor) { + for item in self.node.invoke(&req, &accessor) { CmdDataEncoder::handle_async(item, &mut self.handler, transaction, &mut tw) .await?; } + + req.complete(tx, transaction) } - Interaction::Timed(_) => (), - } + Interaction::Subscribe(req) => { + let mut resume_path = None; + + for item in self.node.subscribing_read(&req, &accessor) { + if let Some(path) = + AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? + { + resume_path = Some(path); + break; + } + } + + req.complete(tx, transaction, resume_path) + } + Interaction::Timed(_) => Ok(false), + Interaction::ResumeRead(req) => { + let mut resume_path = None; - interaction.complete_tx(tx, transaction) + for item in self.node.resume_read(&req, &accessor) { + if let Some(path) = + AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? + { + resume_path = Some(path); + break; + } + } + + req.complete(tx, transaction, resume_path) + } + Interaction::ResumeSubscribe(req) => { + let mut resume_path = None; + + if req.resume_path.is_some() { + for item in self.node.resume_subscribing_read(&req, &accessor) { + if let Some(path) = + AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? + { + resume_path = Some(path); + break; + } + } + } else { + // TODO + let resp = SubscribeResp::new(SUBS_ID.fetch_add(1, Ordering::SeqCst), 40); + resp.to_tlv(&mut tw, TagType::Anonymous)?; + } + + req.complete(tx, transaction, resume_path) + } + } } } pub trait DataHandler { fn handle( &mut self, - interaction: &Interaction, + interaction: Interaction, tx: &mut Packet, transaction: &mut Transaction, ) -> Result; @@ -138,7 +251,7 @@ where { fn handle( &mut self, - interaction: &Interaction, + interaction: Interaction, tx: &mut Packet, transaction: &mut Transaction, ) -> Result { @@ -152,7 +265,7 @@ where { fn handle( &mut self, - interaction: &Interaction, + interaction: Interaction, tx: &mut Packet, transaction: &mut Transaction, ) -> Result { @@ -172,24 +285,24 @@ pub mod asynch { use super::DataModel; pub trait AsyncDataHandler { - async fn handle<'p>( + async fn handle( &mut self, - interaction: &Interaction, - tx: &'p mut Packet, + interaction: Interaction<'_>, + tx: &mut Packet, transaction: &mut Transaction, - ) -> Result, Error>; + ) -> Result; } impl AsyncDataHandler for &mut T where T: AsyncDataHandler, { - async fn handle<'p>( + async fn handle( &mut self, - interaction: &Interaction<'_>, - tx: &'p mut Packet<'_>, + interaction: Interaction<'_>, + tx: &mut Packet<'_>, transaction: &mut Transaction<'_, '_>, - ) -> Result, Error> { + ) -> Result { (**self).handle(interaction, tx, transaction).await } } @@ -198,12 +311,12 @@ pub mod asynch { where T: AsyncHandler, { - async fn handle<'p>( + async fn handle( &mut self, - interaction: &Interaction<'_>, - tx: &'p mut Packet<'_>, + interaction: Interaction<'_>, + tx: &mut Packet<'_>, transaction: &mut Transaction<'_, '_>, - ) -> Result, Error> { + ) -> Result { DataModel::handle_async(self, interaction, tx, transaction).await } } diff --git a/matter/src/data_model/objects/cluster.rs b/matter/src/data_model/objects/cluster.rs index 90c6835d..3818f93b 100644 --- a/matter/src/data_model/objects/cluster.rs +++ b/matter/src/data_model/objects/cluster.rs @@ -64,6 +64,7 @@ pub const ATTRIBUTE_LIST: Attribute = Attribute::new( // TODO: What if we instead of creating this, we just pass the AttrData/AttrPath to the read/write // methods? /// The Attribute Details structure records the details about the attribute under consideration. +#[derive(Debug)] pub struct AttrDetails<'a> { pub node: &'a Node<'a>, /// The actual endpoint ID @@ -129,6 +130,7 @@ impl<'a> AttrDetails<'a> { } } +#[derive(Debug)] pub struct CmdDetails<'a> { pub node: &'a Node<'a>, pub endpoint_id: EndptId, @@ -208,49 +210,23 @@ impl<'a> Cluster<'a> { } } - pub(crate) fn match_attributes<'m>( - &'m self, - accessor: &'m Accessor<'m>, - ep: EndptId, + pub fn match_attributes( + &self, attr: Option, - write: bool, - ) -> impl Iterator + 'm { + ) -> impl Iterator + '_ { self.attributes .iter() .filter(move |attribute| attr.map(|attr| attr == attribute.id).unwrap_or(true)) - .filter(move |attribute| { - let mut access_req = AccessReq::new( - accessor, - GenericPath::new(Some(ep), Some(self.id), Some(attribute.id as _)), - if write { Access::WRITE } else { Access::READ }, - ); - self.check_attr_access(&mut access_req, attribute.access) - .is_ok() - }) - .map(|attribute| attribute.id) } - pub fn match_commands<'m>( - &'m self, - accessor: &'m Accessor<'m>, - ep: EndptId, - cmd: Option, - ) -> impl Iterator + 'm { + pub fn match_commands(&self, cmd: Option) -> impl Iterator + '_ { self.commands .iter() .filter(move |id| cmd.map(|cmd| **id == cmd).unwrap_or(true)) - .filter(move |id| { - let mut access_req = AccessReq::new( - accessor, - GenericPath::new(Some(ep), Some(self.id), Some(**id as _)), - Access::WRITE, - ); - self.check_cmd_access(&mut access_req).is_ok() - }) .copied() } - pub(crate) fn check_attribute( + pub fn check_attribute( &self, accessor: &Accessor, ep: EndptId, @@ -263,16 +239,15 @@ impl<'a> Cluster<'a> { .find(|attribute| attribute.id == attr) .ok_or(IMStatusCode::UnsupportedAttribute)?; - let mut access_req = AccessReq::new( + Self::check_attr_access( accessor, GenericPath::new(Some(ep), Some(self.id), Some(attr as _)), - if write { Access::WRITE } else { Access::READ }, - ); - - self.check_attr_access(&mut access_req, attribute.access) + write, + attribute.access, + ) } - pub(crate) fn check_command( + pub fn check_command( &self, accessor: &Accessor, ep: EndptId, @@ -283,20 +258,24 @@ impl<'a> Cluster<'a> { .find(|id| **id == cmd) .ok_or(IMStatusCode::UnsupportedCommand)?; - let mut access_req = AccessReq::new( + Self::check_cmd_access( accessor, - GenericPath::new(Some(ep), Some(self.id), Some(cmd as _)), - Access::WRITE, - ); - - self.check_cmd_access(&mut access_req) + GenericPath::new(Some(ep), Some(self.id), Some(cmd)), + ) } - fn check_attr_access( - &self, - access_req: &mut AccessReq, + pub(crate) fn check_attr_access( + accessor: &Accessor, + path: GenericPath, + write: bool, target_perms: Access, ) -> Result<(), IMStatusCode> { + let mut access_req = AccessReq::new( + accessor, + path, + if write { Access::WRITE } else { Access::READ }, + ); + if !target_perms.contains(access_req.operation()) { Err(if matches!(access_req.operation(), Access::WRITE) { IMStatusCode::UnsupportedWrite @@ -313,7 +292,12 @@ impl<'a> Cluster<'a> { } } - fn check_cmd_access(&self, access_req: &mut AccessReq) -> Result<(), IMStatusCode> { + pub(crate) fn check_cmd_access( + accessor: &Accessor, + path: GenericPath, + ) -> Result<(), IMStatusCode> { + let mut access_req = AccessReq::new(accessor, path, Access::WRITE); + access_req.set_target_perms( Access::WRITE .union(Access::NEED_OPERATE) diff --git a/matter/src/data_model/objects/encoder.rs b/matter/src/data_model/objects/encoder.rs index 39d2ba6d..b4066e6b 100644 --- a/matter/src/data_model/objects/encoder.rs +++ b/matter/src/data_model/objects/encoder.rs @@ -23,6 +23,7 @@ use crate::interaction_model::core::{IMStatusCode, Transaction}; use crate::interaction_model::messages::ib::{ AttrPath, AttrResp, AttrStatus, CmdDataTag, CmdPath, CmdStatus, InvResp, InvRespTag, }; +use crate::interaction_model::messages::GenericPath; use crate::tlv::UtfStr; use crate::{ error::Error, @@ -127,13 +128,14 @@ impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { item: Result, handler: &T, tw: &mut TLVWriter, - ) -> Result<(), Error> { + ) -> Result, Error> { let status = match item { Ok(attr) => { let encoder = AttrDataEncoder::new(&attr, tw); match handler.read(&attr, encoder) { Ok(()) => None, + Err(Error::NoSpace) => return Ok(Some(attr.path().to_gp())), Err(error) => attr.status(error.into())?, } } @@ -144,7 +146,7 @@ impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; } - Ok(()) + Ok(None) } pub fn handle_write( @@ -172,13 +174,14 @@ impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { item: Result, AttrStatus>, handler: &T, tw: &mut TLVWriter<'_, '_>, - ) -> Result<(), Error> { + ) -> Result, Error> { let status = match item { Ok(attr) => { let encoder = AttrDataEncoder::new(&attr, tw); match handler.read(&attr, encoder).await { Ok(()) => None, + Err(Error::NoSpace) => return Ok(Some(attr.path().to_gp())), Err(error) => attr.status(error.into())?, } } @@ -189,7 +192,7 @@ impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; } - Ok(()) + Ok(None) } #[cfg(feature = "nightly")] diff --git a/matter/src/data_model/objects/endpoint.rs b/matter/src/data_model/objects/endpoint.rs index d0a4fddf..05878edc 100644 --- a/matter/src/data_model/objects/endpoint.rs +++ b/matter/src/data_model/objects/endpoint.rs @@ -19,7 +19,7 @@ use crate::{acl::Accessor, interaction_model::core::IMStatusCode}; use core::fmt; -use super::{AttrId, Cluster, ClusterId, CmdId, DeviceType, EndptId}; +use super::{AttrId, Attribute, Cluster, ClusterId, CmdId, DeviceType, EndptId}; #[derive(Debug, Clone)] pub struct Endpoint<'a> { @@ -29,34 +29,28 @@ pub struct Endpoint<'a> { } impl<'a> Endpoint<'a> { - pub(crate) fn match_attributes<'m>( - &'m self, - accessor: &'m Accessor<'m>, + pub fn match_attributes( + &self, cl: Option, attr: Option, - write: bool, - ) -> impl Iterator + 'm { + ) -> impl Iterator + '_ { self.match_clusters(cl).flat_map(move |cluster| { cluster - .match_attributes(accessor, self.id, attr, write) - .map(move |attr| (cluster.id, attr)) + .match_attributes(attr) + .map(move |attr| (cluster, attr)) }) } - pub(crate) fn match_commands<'m>( - &'m self, - accessor: &'m Accessor<'m>, + pub fn match_commands( + &self, cl: Option, cmd: Option, - ) -> impl Iterator + 'm { - self.match_clusters(cl).flat_map(move |cluster| { - cluster - .match_commands(accessor, self.id, cmd) - .map(move |cmd| (cluster.id, cmd)) - }) + ) -> impl Iterator + '_ { + self.match_clusters(cl) + .flat_map(move |cluster| cluster.match_commands(cmd).map(move |cmd| (cluster, cmd))) } - pub(crate) fn check_attribute( + pub fn check_attribute( &self, accessor: &Accessor, cl: ClusterId, @@ -67,7 +61,7 @@ impl<'a> Endpoint<'a> { .and_then(|cluster| cluster.check_attribute(accessor, self.id, attr, write)) } - pub(crate) fn check_command( + pub fn check_command( &self, accessor: &Accessor, cl: ClusterId, @@ -77,13 +71,13 @@ impl<'a> Endpoint<'a> { .and_then(|cluster| cluster.check_command(accessor, self.id, cmd)) } - fn match_clusters(&self, cl: Option) -> impl Iterator + '_ { + pub fn match_clusters(&self, cl: Option) -> impl Iterator + '_ { self.clusters .iter() .filter(move |cluster| cl.map(|id| id == cluster.id).unwrap_or(true)) } - fn check_cluster(&self, cl: ClusterId) -> Result<&Cluster, IMStatusCode> { + pub fn check_cluster(&self, cl: ClusterId) -> Result<&Cluster, IMStatusCode> { self.clusters .iter() .find(|cluster| cluster.id == cl) diff --git a/matter/src/data_model/objects/handler.rs b/matter/src/data_model/objects/handler.rs index 052d6906..7758427f 100644 --- a/matter/src/data_model/objects/handler.rs +++ b/matter/src/data_model/objects/handler.rs @@ -186,9 +186,16 @@ macro_rules! handler_chain_type { ($h:ty) => { $crate::data_model::objects::ChainedHandler<$h, $crate::data_model::objects::EmptyHandler> }; - ($h1:ty, $($rest:ty),+) => { + ($h1:ty $(, $rest:ty)+) => { $crate::data_model::objects::ChainedHandler<$h1, handler_chain_type!($($rest),+)> }; + + ($h:ty | $f:ty) => { + $crate::data_model::objects::ChainedHandler<$h, $f> + }; + ($h1:ty $(, $rest:ty)+ | $f:ty) => { + $crate::data_model::objects::ChainedHandler<$h1, handler_chain_type!($($rest),+ | $f)> + }; } #[cfg(feature = "nightly")] diff --git a/matter/src/data_model/objects/node.rs b/matter/src/data_model/objects/node.rs index 4ec1765c..3ee3af27 100644 --- a/matter/src/data_model/objects/node.rs +++ b/matter/src/data_model/objects/node.rs @@ -19,7 +19,7 @@ use crate::{ acl::Accessor, data_model::objects::Endpoint, interaction_model::{ - core::IMStatusCode, + core::{IMStatusCode, ResumeReadReq, ResumeSubscribeReq}, messages::{ ib::{AttrPath, AttrStatus, CmdStatus, DataVersionFilter}, msg::{InvReq, ReadReq, SubscribeReq, WriteReq}, @@ -27,16 +27,16 @@ use crate::{ }, }, // TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer - tlv::{TLVArray, TLVElement}, + tlv::{TLVArray, TLVArrayIter, TLVElement}, }; use core::{ fmt, iter::{once, Once}, }; -use super::{AttrDetails, AttrId, ClusterId, CmdDetails, CmdId, EndptId}; +use super::{AttrDetails, AttrId, Attribute, Cluster, ClusterId, CmdDetails, CmdId, EndptId}; -enum WildcardIter { +pub enum WildcardIter { None, Single(Once), Wildcard(T), @@ -57,6 +57,41 @@ where } } +pub trait Iterable { + type Item; + + type Iterator<'a>: Iterator + where + Self: 'a; + + fn iter(&self) -> Self::Iterator<'_>; +} + +impl<'a> Iterable for Option<&'a TLVArray<'a, DataVersionFilter>> { + type Item = DataVersionFilter; + + type Iterator<'i> = WildcardIter, DataVersionFilter> where Self: 'i; + + fn iter(&self) -> Self::Iterator<'_> { + if let Some(filters) = self { + WildcardIter::Wildcard(filters.iter()) + } else { + WildcardIter::None + } + } +} + +impl<'a> Iterable for &'a [DataVersionFilter] { + type Item = DataVersionFilter; + + type Iterator<'i> = core::iter::Copied> where Self: 'i; + + fn iter(&self) -> Self::Iterator<'_> { + let slice: &[DataVersionFilter] = self; + slice.iter().copied() + } +} + #[derive(Debug, Clone)] pub struct Node<'a> { pub id: u16, @@ -73,10 +108,30 @@ impl<'a> Node<'a> { 's: 'm, { self.read_attr_requests( - req.attr_requests.as_ref(), + req.attr_requests + .iter() + .flat_map(|attr_requests| attr_requests.iter()), req.dataver_filters.as_ref(), req.fabric_filtered, accessor, + None, + ) + } + + pub fn resume_read<'s, 'm>( + &'s self, + req: &'m ResumeReadReq, + accessor: &'m Accessor<'m>, + ) -> impl Iterator> + 'm + where + 's: 'm, + { + self.read_attr_requests( + req.paths.iter().copied(), + req.filters.as_slice(), + req.fabric_filtered, + accessor, + Some(req.resume_path), ) } @@ -89,60 +144,115 @@ impl<'a> Node<'a> { 's: 'm, { self.read_attr_requests( - req.attr_requests.as_ref(), + req.attr_requests + .iter() + .flat_map(|attr_requests| attr_requests.iter()), req.dataver_filters.as_ref(), req.fabric_filtered, accessor, + None, + ) + } + + pub fn resume_subscribing_read<'s, 'm>( + &'s self, + req: &'m ResumeSubscribeReq, + accessor: &'m Accessor<'m>, + ) -> impl Iterator> + 'm + where + 's: 'm, + { + self.read_attr_requests( + req.paths.iter().copied(), + req.filters.as_slice(), + req.fabric_filtered, + accessor, + Some(req.resume_path.unwrap()), ) } - fn read_attr_requests<'s, 'm>( + fn read_attr_requests<'s, 'm, P, D>( &'s self, - attr_requests: Option<&'m TLVArray>, - dataver_filters: Option<&'m TLVArray>, + attr_requests: P, + dataver_filters: D, fabric_filtered: bool, accessor: &'m Accessor<'m>, + from: Option, ) -> impl Iterator> + 'm where 's: 'm, + P: Iterator + 'm, + D: Iterable + Clone + 'm, { - if let Some(attr_requests) = attr_requests.as_ref() { - WildcardIter::Wildcard(attr_requests.iter().flat_map( - move |path| match self.expand_attr(accessor, path.to_gp(), false) { - Ok(iter) => { - let wildcard = matches!(iter, WildcardIter::Wildcard(_)); - - WildcardIter::Wildcard(iter.map(move |(ep, cl, attr)| { - let dataver_filter = dataver_filters - .as_ref() - .iter() - .flat_map(|array| array.iter()) - .find_map(|filter| { - (filter.path.endpoint == ep && filter.path.cluster == cl) - .then_some(filter.data_ver) - }); - - Ok(AttrDetails { - node: self, - endpoint_id: ep, - cluster_id: cl, - attr_id: attr, - list_index: path.list_index, - fab_idx: accessor.fab_idx, - fab_filter: fabric_filtered, - dataver: dataver_filter, - wildcard, - }) - })) - } - Err(err) => { - WildcardIter::Single(once(Err(AttrStatus::new(&path.to_gp(), err, 0)))) + attr_requests.flat_map(move |path| { + if path.to_gp().is_wildcard() { + let dataver_filters = dataver_filters.clone(); + let from = from; + + let iter = self + .match_attributes(path.endpoint, path.cluster, path.attr) + .skip_while(move |(ep, cl, attr)| { + !Self::matches(from.as_ref(), ep.id, cl.id, attr.id as _) + }) + .filter(move |(ep, cl, attr)| { + Cluster::check_attr_access( + accessor, + GenericPath::new(Some(ep.id), Some(cl.id), Some(attr.id as _)), + false, + attr.access, + ) + .is_ok() + }) + .map(move |(ep, cl, attr)| { + let dataver = dataver_filters.iter().find_map(|filter| { + (filter.path.endpoint == ep.id && filter.path.cluster == cl.id) + .then_some(filter.data_ver) + }); + + Ok(AttrDetails { + node: self, + endpoint_id: ep.id, + cluster_id: cl.id, + attr_id: attr.id, + list_index: path.list_index, + fab_idx: accessor.fab_idx, + fab_filter: fabric_filtered, + dataver, + wildcard: true, + }) + }); + + WildcardIter::Wildcard(iter) + } else { + let ep = path.endpoint.unwrap(); + let cl = path.cluster.unwrap(); + let attr = path.attr.unwrap(); + + let result = match self.check_attribute(accessor, ep, cl, attr, false) { + Ok(()) => { + let dataver = dataver_filters.iter().find_map(|filter| { + (filter.path.endpoint == ep && filter.path.cluster == cl) + .then_some(filter.data_ver) + }); + + Ok(AttrDetails { + node: self, + endpoint_id: ep, + cluster_id: cl, + attr_id: attr, + list_index: path.list_index, + fab_idx: accessor.fab_idx, + fab_filter: fabric_filtered, + dataver, + wildcard: false, + }) } - }, - )) - } else { - WildcardIter::None - } + Err(err) => Err(AttrStatus::new(&path.to_gp(), err, 0)), + }; + + WildcardIter::Single(once(result)) + } + }) } pub fn write<'m>( @@ -163,34 +273,64 @@ impl<'a> Node<'a> { IMStatusCode::UnsupportedAttribute, 0, )))) + } else if attr_data.path.to_gp().is_wildcard() { + let iter = self + .match_attributes( + attr_data.path.endpoint, + attr_data.path.cluster, + attr_data.path.attr, + ) + .filter(move |(ep, cl, attr)| { + Cluster::check_attr_access( + accessor, + GenericPath::new(Some(ep.id), Some(cl.id), Some(attr.id as _)), + true, + attr.access, + ) + .is_ok() + }) + .map(move |(ep, cl, attr)| { + Ok(( + AttrDetails { + node: self, + endpoint_id: ep.id, + cluster_id: cl.id, + attr_id: attr.id, + list_index: attr_data.path.list_index, + fab_idx: accessor.fab_idx, + fab_filter: false, + dataver: attr_data.data_ver, + wildcard: true, + }, + attr_data.data.unwrap_tlv().unwrap(), + )) + }); + + WildcardIter::Wildcard(iter) } else { - match self.expand_attr(accessor, attr_data.path.to_gp(), true) { - Ok(iter) => { - let wildcard = matches!(iter, WildcardIter::Wildcard(_)); - - WildcardIter::Wildcard(iter.map(move |(ep, cl, attr)| { - Ok(( - AttrDetails { - node: self, - endpoint_id: ep, - cluster_id: cl, - attr_id: attr, - list_index: attr_data.path.list_index, - fab_idx: accessor.fab_idx, - fab_filter: false, - dataver: attr_data.data_ver, - wildcard, - }, - attr_data.data.unwrap_tlv().unwrap(), - )) - })) - } - Err(err) => WildcardIter::Single(once(Err(AttrStatus::new( - &attr_data.path.to_gp(), - err, - 0, - )))), - } + let ep = attr_data.path.endpoint.unwrap(); + let cl = attr_data.path.cluster.unwrap(); + let attr = attr_data.path.attr.unwrap(); + + let result = match self.check_attribute(accessor, ep, cl, attr, true) { + Ok(()) => Ok(( + AttrDetails { + node: self, + endpoint_id: ep, + cluster_id: cl, + attr_id: attr, + list_index: attr_data.path.list_index, + fab_idx: accessor.fab_idx, + fab_filter: false, + dataver: attr_data.data_ver, + wildcard: false, + }, + attr_data.data.unwrap_tlv().unwrap(), + )), + Err(err) => Err(AttrStatus::new(&attr_data.path.to_gp(), err, 0)), + }; + + WildcardIter::Single(once(result)) } }) } @@ -200,136 +340,99 @@ impl<'a> Node<'a> { req: &'m InvReq, accessor: &'m Accessor<'m>, ) -> impl Iterator), CmdStatus>> + 'm { - if let Some(inv_requests) = req.inv_requests.as_ref() { - WildcardIter::Wildcard(inv_requests.iter().flat_map(move |cmd_data| { - match self.expand_cmd(accessor, cmd_data.path.path) { - Ok(iter) => { - let wildcard = matches!(iter, WildcardIter::Wildcard(_)); - - WildcardIter::Wildcard(iter.map(move |(ep, cl, cmd)| { + req.inv_requests + .iter() + .flat_map(|inv_requests| inv_requests.iter()) + .flat_map(move |cmd_data| { + if cmd_data.path.path.is_wildcard() { + let iter = self + .match_commands( + cmd_data.path.path.endpoint, + cmd_data.path.path.cluster, + cmd_data.path.path.leaf.map(|leaf| leaf as _), + ) + .filter(move |(ep, cl, cmd)| { + Cluster::check_cmd_access( + accessor, + GenericPath::new(Some(ep.id), Some(cl.id), Some(*cmd)), + ) + .is_ok() + }) + .map(move |(ep, cl, cmd)| { Ok(( CmdDetails { node: self, - endpoint_id: ep, - cluster_id: cl, + endpoint_id: ep.id, + cluster_id: cl.id, cmd_id: cmd, - wildcard, + wildcard: true, }, cmd_data.data.unwrap_tlv().unwrap(), )) - })) - } - Err(err) => { - WildcardIter::Single(once(Err(CmdStatus::new(cmd_data.path, err, 0)))) - } - } - })) - } else { - WildcardIter::None - } - } + }); - fn expand_attr<'m>( - &'m self, - accessor: &'m Accessor<'m>, - path: GenericPath, - write: bool, - ) -> Result< - WildcardIter< - impl Iterator + 'm, - (EndptId, ClusterId, AttrId), - >, - IMStatusCode, - > { - if path.is_wildcard() { - Ok(WildcardIter::Wildcard(self.match_attributes( - accessor, - path.endpoint, - path.cluster, - path.leaf.map(|leaf| leaf as u16), - write, - ))) - } else { - self.check_attribute( - accessor, - path.endpoint.unwrap(), - path.cluster.unwrap(), - path.leaf.unwrap() as _, - write, - )?; - - Ok(WildcardIter::Single(once(( - path.endpoint.unwrap(), - path.cluster.unwrap(), - path.leaf.unwrap() as _, - )))) - } + WildcardIter::Wildcard(iter) + } else { + let ep = cmd_data.path.path.endpoint.unwrap(); + let cl = cmd_data.path.path.cluster.unwrap(); + let cmd = cmd_data.path.path.leaf.unwrap(); + + let result = match self.check_command(accessor, ep, cl, cmd) { + Ok(()) => Ok(( + CmdDetails { + node: self, + endpoint_id: cmd_data.path.path.endpoint.unwrap(), + cluster_id: cmd_data.path.path.cluster.unwrap(), + cmd_id: cmd_data.path.path.leaf.unwrap(), + wildcard: false, + }, + cmd_data.data.unwrap_tlv().unwrap(), + )), + Err(err) => Err(CmdStatus::new(cmd_data.path, err, 0)), + }; + + WildcardIter::Single(once(result)) + } + }) } - fn expand_cmd<'m>( - &'m self, - accessor: &'m Accessor<'m>, - path: GenericPath, - ) -> Result< - WildcardIter< - impl Iterator + 'm, - (EndptId, ClusterId, CmdId), - >, - IMStatusCode, - > { - if path.is_wildcard() { - Ok(WildcardIter::Wildcard(self.match_commands( - accessor, - path.endpoint, - path.cluster, - path.leaf, - ))) + fn matches(path: Option<&GenericPath>, ep: EndptId, cl: ClusterId, leaf: u32) -> bool { + if let Some(path) = path { + path.endpoint.map(|id| id == ep).unwrap_or(true) + && path.cluster.map(|id| id == cl).unwrap_or(true) + && path.leaf.map(|id| id == leaf).unwrap_or(true) } else { - self.check_command( - accessor, - path.endpoint.unwrap(), - path.cluster.unwrap(), - path.leaf.unwrap(), - )?; - - Ok(WildcardIter::Single(once(( - path.endpoint.unwrap(), - path.cluster.unwrap(), - path.leaf.unwrap(), - )))) + true } } - fn match_attributes<'m>( - &'m self, - accessor: &'m Accessor<'m>, + pub fn match_attributes( + &self, ep: Option, cl: Option, attr: Option, - write: bool, - ) -> impl Iterator + 'm { + ) -> impl Iterator + '_ { self.match_endpoints(ep).flat_map(move |endpoint| { endpoint - .match_attributes(accessor, cl, attr, write) - .map(move |(cl, attr)| (endpoint.id, cl, attr)) + .match_attributes(cl, attr) + .map(move |(cl, attr)| (endpoint, cl, attr)) }) } - fn match_commands<'m>( - &'m self, - accessor: &'m Accessor<'m>, + pub fn match_commands( + &self, ep: Option, cl: Option, cmd: Option, - ) -> impl Iterator + 'm { + ) -> impl Iterator + '_ { self.match_endpoints(ep).flat_map(move |endpoint| { endpoint - .match_commands(accessor, cl, cmd) - .map(move |(cl, cmd)| (endpoint.id, cl, cmd)) + .match_commands(cl, cmd) + .map(move |(cl, cmd)| (endpoint, cl, cmd)) }) } - fn check_attribute( + pub fn check_attribute( &self, accessor: &Accessor, ep: EndptId, @@ -341,7 +444,7 @@ impl<'a> Node<'a> { .and_then(|endpoint| endpoint.check_attribute(accessor, cl, attr, write)) } - fn check_command( + pub fn check_command( &self, accessor: &Accessor, ep: EndptId, @@ -352,13 +455,13 @@ impl<'a> Node<'a> { .and_then(|endpoint| endpoint.check_command(accessor, cl, cmd)) } - fn match_endpoints(&self, ep: Option) -> impl Iterator + '_ { + pub fn match_endpoints(&self, ep: Option) -> impl Iterator + '_ { self.endpoints .iter() .filter(move |endpoint| ep.map(|id| id == endpoint.id).unwrap_or(true)) } - fn check_endpoint(&self, ep: EndptId) -> Result<&Endpoint, IMStatusCode> { + pub fn check_endpoint(&self, ep: EndptId) -> Result<&Endpoint, IMStatusCode> { self.endpoints .iter() .find(|endpoint| endpoint.id == ep) diff --git a/matter/src/data_model/root_endpoint.rs b/matter/src/data_model/root_endpoint.rs index 44131b9f..ebcbb140 100644 --- a/matter/src/data_model/root_endpoint.rs +++ b/matter/src/data_model/root_endpoint.rs @@ -21,19 +21,24 @@ use super::{ noc::{self, NocCluster}, nw_commissioning::{self, NwCommCluster}, }, - system_model::access_control::{self, AccessControlCluster}, + system_model::{ + access_control::{self, AccessControlCluster}, + descriptor::{self, DescriptorCluster}, + }, }; pub type RootEndpointHandler<'a> = handler_chain_type!( - AccessControlCluster<'a>, - NocCluster<'a>, - AdminCommCluster<'a>, - NwCommCluster, + DescriptorCluster, + BasicInfoCluster<'a>, GenCommCluster, - BasicInfoCluster<'a> + NwCommCluster, + AdminCommCluster<'a>, + NocCluster<'a>, + AccessControlCluster<'a> ); -pub const CLUSTERS: [Cluster<'static>; 6] = [ +pub const CLUSTERS: [Cluster<'static>; 7] = [ + descriptor::CLUSTER, cluster_basic_information::CLUSTER, general_commissioning::CLUSTER, nw_commissioning::CLUSTER, @@ -77,32 +82,29 @@ pub fn wrap<'a>( EmptyHandler .chain( endpoint_id, - cluster_basic_information::CLUSTER.id, - BasicInfoCluster::new(basic_info, rand), - ) - .chain( - endpoint_id, - general_commissioning::CLUSTER.id, - GenCommCluster::new(rand), + access_control::ID, + AccessControlCluster::new(acl, rand), ) .chain( endpoint_id, - nw_commissioning::CLUSTER.id, - NwCommCluster::new(rand), + noc::ID, + NocCluster::new(dev_att, fabric, acl, failsafe, mdns_mgr, epoch, rand), ) .chain( endpoint_id, - admin_commissioning::CLUSTER.id, + admin_commissioning::ID, AdminCommCluster::new(pase, mdns_mgr, rand), ) + .chain(endpoint_id, nw_commissioning::ID, NwCommCluster::new(rand)) .chain( endpoint_id, - noc::CLUSTER.id, - NocCluster::new(dev_att, fabric, acl, failsafe, mdns_mgr, epoch, rand), + general_commissioning::ID, + GenCommCluster::new(rand), ) .chain( endpoint_id, - access_control::CLUSTER.id, - AccessControlCluster::new(acl, rand), + cluster_basic_information::ID, + BasicInfoCluster::new(basic_info, rand), ) + .chain(endpoint_id, descriptor::ID, DescriptorCluster::new(rand)) } diff --git a/matter/src/data_model/sdm/general_commissioning.rs b/matter/src/data_model/sdm/general_commissioning.rs index aea37c7a..0c007d1b 100644 --- a/matter/src/data_model/sdm/general_commissioning.rs +++ b/matter/src/data_model/sdm/general_commissioning.rs @@ -215,7 +215,7 @@ impl GenCommCluster { encoder .with_command(RespCommands::ArmFailsafeResp as _)? - .set(&cmd_data) + .set(cmd_data) } fn handle_command_setregulatoryconfig( @@ -238,7 +238,7 @@ impl GenCommCluster { encoder .with_command(RespCommands::SetRegulatoryConfigResp as _)? - .set(&cmd_data) + .set(cmd_data) } fn handle_command_commissioningcomplete( @@ -272,7 +272,7 @@ impl GenCommCluster { encoder .with_command(RespCommands::CommissioningCompleteResp as _)? - .set(&cmd_data) + .set(cmd_data) } } diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index 0258f3a5..acaea504 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -398,7 +398,7 @@ impl<'a> NocCluster<'a> { encoder .with_command(RespCommands::NOCResp as _)? - .set(&cmd_data) + .set(cmd_data) } fn handle_command_updatefablabel( @@ -527,7 +527,7 @@ impl<'a> NocCluster<'a> { encoder .with_command(RespCommands::CertChainResp as _)? - .set(&cmd_data) + .set(cmd_data) } fn handle_command_csrrequest( diff --git a/matter/src/data_model/sdm/nw_commissioning.rs b/matter/src/data_model/sdm/nw_commissioning.rs index 7afff7a2..47ffe6ed 100644 --- a/matter/src/data_model/sdm/nw_commissioning.rs +++ b/matter/src/data_model/sdm/nw_commissioning.rs @@ -52,8 +52,16 @@ impl NwCommCluster { } impl Handler for NwCommCluster { - fn read(&self, _attr: &AttrDetails, _encoder: AttrDataEncoder) -> Result<(), Error> { - Err(Error::AttributeNotFound) + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + Err(Error::AttributeNotFound) + } + } else { + Ok(()) + } } } diff --git a/matter/src/data_model/system_model/access_control.rs b/matter/src/data_model/system_model/access_control.rs index 3980a434..ffba5e67 100644 --- a/matter/src/data_model/system_model/access_control.rs +++ b/matter/src/data_model/system_model/access_control.rs @@ -20,7 +20,7 @@ use core::convert::TryInto; use strum::{EnumDiscriminants, FromRepr}; -use crate::acl::{AclEntry, AclMgr}; +use crate::acl::{self, AclEntry, AclMgr}; use crate::data_model::objects::*; use crate::interaction_model::messages::ib::{attr_list_write, ListOperation}; use crate::tlv::{FromTLV, TLVElement, TagType, ToTLV}; @@ -116,9 +116,14 @@ impl<'a> AccessControlCluster<'a> { writer.complete() } - _ => { - error!("Attribute not yet supported: this shouldn't happen"); - Err(Error::AttributeNotFound) + Attributes::SubjectsPerEntry(codec) => { + codec.encode(writer, acl::SUBJECTS_PER_ENTRY as u16) + } + Attributes::TargetsPerEntry(codec) => { + codec.encode(writer, acl::TARGETS_PER_ENTRY as u16) + } + Attributes::EntriesPerFabric(codec) => { + codec.encode(writer, acl::ENTRIES_PER_FABRIC as u16) } } } @@ -365,7 +370,7 @@ mod tests { writebuf.as_slice() ); } - writebuf.reset(0); + writebuf.reset(); // Test 2, only single entry is read in the response with fabric filtering and fabric idx 1 { @@ -400,7 +405,7 @@ mod tests { writebuf.as_slice() ); } - writebuf.reset(0); + writebuf.reset(); // Test 3, only single entry is read in the response with fabric filtering and fabric idx 2 { diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index 42c55fdf..b7e2425a 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -495,7 +495,11 @@ impl FabricMgr { } #[cfg(feature = "nightly")] - pub async fn load_async(&mut self, mut psm: T, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> + pub async fn load_async( + &mut self, + mut psm: T, + mdns_mgr: &mut MdnsMgr<'_>, + ) -> Result<(), Error> where T: crate::persist::asynch::AsyncPsm, { diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index 935740e6..162d64c6 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -21,14 +21,23 @@ use crate::{ data_model::core::DataHandler, error::*, tlv::{get_root_node_struct, print_tlv_list, FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, - transport::{exchange::ExchangeCtx, packet::Packet, proto_ctx::ProtoCtx, session::Session}, + transport::{ + exchange::{Exchange, ExchangeCtx}, + packet::Packet, + proto_ctx::ProtoCtx, + session::Session, + }, }; use colored::Colorize; use log::{error, info}; use num; use num_derive::FromPrimitive; -use super::messages::msg::{self, InvReq, ReadReq, StatusResp, SubscribeReq, TimedReq, WriteReq}; +use super::messages::{ + ib::{AttrPath, DataVersionFilter}, + msg::{self, InvReq, ReadReq, StatusResp, SubscribeReq, TimedReq, WriteReq}, + GenericPath, +}; #[macro_export] macro_rules! cmd_enter { @@ -132,6 +141,14 @@ impl<'a, 'b> Transaction<'a, 'b> { } } + pub fn exch(&self) -> &Exchange { + self.ctx.exch + } + + pub fn exch_mut(&mut self) -> &mut Exchange { + self.ctx.exch + } + pub fn session(&self) -> &Session { self.ctx.sess.session() } @@ -182,17 +199,25 @@ impl<'a, 'b> Transaction<'a, 'b> { /* Interaction Model ID as per the Matter Spec */ const PROTO_ID_INTERACTION_MODEL: usize = 0x01; +const MAX_RESUME_PATHS: usize = 128; +const MAX_RESUME_DATAVER_FILTERS: usize = 128; + +// This is the amount of space we reserve for other things to be attached towards +// the end of long reads. +const LONG_READS_TLV_RESERVE_SIZE: usize = 24; + pub enum Interaction<'a> { Read(ReadReq<'a>), Write(WriteReq<'a>), Invoke(InvReq<'a>), Subscribe(SubscribeReq<'a>), - Status(StatusResp), Timed(TimedReq), + ResumeRead(ResumeReadReq), + ResumeSubscribe(ResumeSubscribeReq), } impl<'a> Interaction<'a> { - pub fn new(rx: &'a Packet) -> Result { + fn new(rx: &'a Packet, transaction: &mut Transaction) -> Result, Error> { let opcode: OpCode = num::FromPrimitive::from_u8(rx.get_proto_opcode()).ok_or(Error::Invalid)?; @@ -202,253 +227,485 @@ impl<'a> Interaction<'a> { print_tlv_list(rx_data); match opcode { - OpCode::ReadRequest => Ok(Self::Read(ReadReq::from_tlv(&get_root_node_struct( + OpCode::ReadRequest => Ok(Some(Self::Read(ReadReq::from_tlv(&get_root_node_struct( rx_data, - )?)?)), - OpCode::WriteRequest => Ok(Self::Write(WriteReq::from_tlv(&get_root_node_struct( - rx_data, - )?)?)), - OpCode::InvokeRequest => Ok(Self::Invoke(InvReq::from_tlv(&get_root_node_struct( - rx_data, - )?)?)), - OpCode::SubscribeRequest => Ok(Self::Subscribe(SubscribeReq::from_tlv( + )?)?))), + OpCode::WriteRequest => Ok(Some(Self::Write(WriteReq::from_tlv( &get_root_node_struct(rx_data)?, - )?)), - OpCode::StatusResponse => Ok(Self::Status(StatusResp::from_tlv( + )?))), + OpCode::InvokeRequest => Ok(Some(Self::Invoke(InvReq::from_tlv( &get_root_node_struct(rx_data)?, - )?)), - OpCode::TimedRequest => Ok(Self::Timed(TimedReq::from_tlv(&get_root_node_struct( - rx_data, - )?)?)), + )?))), + OpCode::SubscribeRequest => Ok(Some(Self::Subscribe(SubscribeReq::from_tlv( + &get_root_node_struct(rx_data)?, + )?))), + OpCode::StatusResponse => { + let resp = StatusResp::from_tlv(&get_root_node_struct(rx_data)?)?; + + if resp.status == IMStatusCode::Success { + if let Some(req) = transaction.exch_mut().take_suspended_read_req() { + Ok(Some(Self::ResumeRead(req))) + } else if let Some(req) = transaction.exch_mut().take_suspended_subscribe_req() + { + Ok(Some(Self::ResumeSubscribe(req))) + } else { + Ok(None) + } + } else { + Ok(None) + } + } + OpCode::TimedRequest => Ok(Some(Self::Timed(TimedReq::from_tlv( + &get_root_node_struct(rx_data)?, + )?))), _ => { - error!("Opcode Not Handled: {:?}", opcode); + error!("Opcode not handled: {:?}", opcode); Err(Error::InvalidOpcode) } } } - pub fn initiate_tx( - &self, + pub fn initiate( + rx: &'a Packet, + tx: &mut Packet, + transaction: &mut Transaction, + ) -> Result, Error> { + if let Some(interaction) = Self::new(rx, transaction)? { + let initiated = match &interaction { + Interaction::Read(req) => req.initiate(tx, transaction)?, + Interaction::Write(req) => req.initiate(tx, transaction)?, + Interaction::Invoke(req) => req.initiate(tx, transaction)?, + Interaction::Subscribe(req) => req.initiate(tx, transaction)?, + Interaction::Timed(req) => { + req.process(tx, transaction)?; + false + } + Interaction::ResumeRead(req) => req.initiate(tx, transaction)?, + Interaction::ResumeSubscribe(req) => req.initiate(tx, transaction)?, + }; + + Ok(initiated.then_some(interaction)) + } else { + Ok(None) + } + } + + fn create_status_response(tx: &mut Packet, status: IMStatusCode) -> Result<(), Error> { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::StatusResponse as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + let status = StatusResp { status }; + status.to_tlv(&mut tw, TagType::Anonymous) + } +} + +impl<'a> ReadReq<'a> { + fn suspend(self, resume_path: GenericPath) -> ResumeReadReq { + ResumeReadReq { + paths: self + .attr_requests + .iter() + .flat_map(|attr_requests| attr_requests.iter()) + .collect(), + filters: self + .dataver_filters + .iter() + .flat_map(|filters| filters.iter()) + .collect(), + fabric_filtered: self.fabric_filtered, + resume_path, + } + } + + fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::ReportData as u8); + + let mut tw = Self::reserve_long_read_space(tx)?; + + tw.start_struct(TagType::Anonymous)?; + + if self.attr_requests.is_some() { + tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; + } + + Ok(true) + } + + pub fn complete( + self, tx: &mut Packet, transaction: &mut Transaction, + resume_path: Option, ) -> Result { - let reply = match self { - Self::Read(request) => { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); - tx.set_proto_opcode(OpCode::ReportData as u8); + let mut tw = Self::restore_long_read_space(tx)?; - let mut tw = TLVWriter::new(tx.get_writebuf()?); + if self.attr_requests.is_some() { + tw.end_container()?; + } - tw.start_struct(TagType::Anonymous)?; + let more_chunks = if let Some(resume_path) = resume_path { + tw.bool( + TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), + true, + )?; - if request.attr_requests.is_some() { - tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; - } + transaction + .exch_mut() + .set_suspended_read_req(self.suspend(resume_path)); + true + } else { + false + }; - false - } - Self::Write(_) => { - if transaction.has_timed_out() { - Self::create_status_response(tx, IMStatusCode::Timeout)?; + tw.bool( + TagType::Context(msg::ReportDataTag::SupressResponse as u8), + !more_chunks, + )?; - transaction.complete(); - transaction.ctx.exch.close(); + tw.end_container()?; - true - } else { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); - tx.set_proto_opcode(OpCode::WriteResponse as u8); + if !more_chunks { + transaction.complete(); + } - let mut tw = TLVWriter::new(tx.get_writebuf()?); + Ok(true) + } - tw.start_struct(TagType::Anonymous)?; - tw.start_array(TagType::Context(msg::WriteRespTag::WriteResponses as u8))?; + fn reserve_long_read_space<'p, 'b>(tx: &'p mut Packet<'b>) -> Result, Error> { + let wb = tx.get_writebuf()?; + wb.shrink(LONG_READS_TLV_RESERVE_SIZE)?; - false - } - } - Self::Invoke(request) => { - if transaction.has_timed_out() { - Self::create_status_response(tx, IMStatusCode::Timeout)?; + Ok(TLVWriter::new(wb)) + } - transaction.complete(); - transaction.ctx.exch.close(); + fn restore_long_read_space<'p, 'b>(tx: &'p mut Packet<'b>) -> Result, Error> { + let wb = tx.get_writebuf()?; + wb.expand(LONG_READS_TLV_RESERVE_SIZE)?; - true - } else { - let timed_tx = transaction.get_timeout().map(|_| true); - let timed_request = request.timed_request.filter(|a| *a); + Ok(TLVWriter::new(wb)) + } +} - // Either both should be None, or both should be Some(true) - if timed_tx != timed_request { - Self::create_status_response(tx, IMStatusCode::TimedRequestMisMatch)?; +impl<'a> WriteReq<'a> { + fn initiate(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result { + if transaction.has_timed_out() { + Interaction::create_status_response(tx, IMStatusCode::Timeout)?; - true - } else { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); - tx.set_proto_opcode(OpCode::InvokeResponse as u8); + transaction.complete(); + transaction.ctx.exch.close(); - let mut tw = TLVWriter::new(tx.get_writebuf()?); + Ok(false) + } else { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::WriteResponse as u8); - tw.start_struct(TagType::Anonymous)?; + let mut tw = TLVWriter::new(tx.get_writebuf()?); - // Suppress Response -> TODO: Need to revisit this for cases where we send a command back - tw.bool( - TagType::Context(msg::InvRespTag::SupressResponse as u8), - false, - )?; + tw.start_struct(TagType::Anonymous)?; + tw.start_array(TagType::Context(msg::WriteRespTag::WriteResponses as u8))?; - if request.inv_requests.is_some() { - tw.start_array(TagType::Context( - msg::InvRespTag::InvokeResponses as u8, - ))?; - } + Ok(true) + } + } - false - } - } - } - Self::Subscribe(request) => { + pub fn complete(self, tx: &mut Packet, transaction: &mut Transaction) -> Result { + let suppress = self.supress_response.unwrap_or_default(); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + tw.end_container()?; + tw.end_container()?; + + transaction.complete(); + + Ok(if suppress { + error!("Supress response is set, is this the expected handling?"); + false + } else { + true + }) + } +} + +impl<'a> InvReq<'a> { + fn initiate(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result { + if transaction.has_timed_out() { + Interaction::create_status_response(tx, IMStatusCode::Timeout)?; + + transaction.complete(); + transaction.ctx.exch.close(); + + Ok(false) + } else { + let timed_tx = transaction.get_timeout().map(|_| true); + let timed_request = self.timed_request.filter(|a| *a); + + // Either both should be None, or both should be Some(true) + if timed_tx != timed_request { + Interaction::create_status_response(tx, IMStatusCode::TimedRequestMisMatch)?; + + Ok(false) + } else { tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); - tx.set_proto_opcode(OpCode::ReportData as u8); + tx.set_proto_opcode(OpCode::InvokeResponse as u8); let mut tw = TLVWriter::new(tx.get_writebuf()?); tw.start_struct(TagType::Anonymous)?; - if request.attr_requests.is_some() { - tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; + // Suppress Response -> TODO: Need to revisit this for cases where we send a command back + tw.bool( + TagType::Context(msg::InvRespTag::SupressResponse as u8), + false, + )?; + + if self.inv_requests.is_some() { + tw.start_array(TagType::Context(msg::InvRespTag::InvokeResponses as u8))?; } - true + Ok(true) } - Self::Status(_) => { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); - tx.set_proto_opcode(OpCode::SubscribeResponse as u8); + } + } - let mut tw = TLVWriter::new(tx.get_writebuf()?); + pub fn complete(self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { + let mut tw = TLVWriter::new(tx.get_writebuf()?); - tw.start_struct(TagType::Anonymous)?; + if self.inv_requests.is_some() { + tw.end_container()?; + } - true - } - Self::Timed(request) => { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); - tx.set_proto_opcode(OpCode::StatusResponse as u8); + tw.end_container()?; - let mut tw = TLVWriter::new(tx.get_writebuf()?); + Ok(true) + } +} - transaction.set_timeout(request.timeout.into()); +impl TimedReq { + pub fn process(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result<(), Error> { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::StatusResponse as u8); - let status = StatusResp { - status: IMStatusCode::Success, - }; + let mut tw = TLVWriter::new(tx.get_writebuf()?); - status.to_tlv(&mut tw, TagType::Anonymous)?; + transaction.set_timeout(self.timeout.into()); - true - } + let status = StatusResp { + status: IMStatusCode::Success, }; - Ok(!reply) + status.to_tlv(&mut tw, TagType::Anonymous)?; + + Ok(()) } +} - pub fn complete_tx( - &self, +impl<'a> SubscribeReq<'a> { + fn suspend(&self, resume_path: Option) -> ResumeSubscribeReq { + ResumeSubscribeReq { + paths: self + .attr_requests + .iter() + .flat_map(|attr_requests| attr_requests.iter()) + .collect(), + filters: self + .dataver_filters + .iter() + .flat_map(|filters| filters.iter()) + .collect(), + fabric_filtered: self.fabric_filtered, + resume_path, + keep_subs: self.keep_subs, + min_int_floor: self.min_int_floor, + max_int_ceil: self.max_int_ceil, + } + } + + fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::ReportData as u8); + + let mut tw = ReadReq::reserve_long_read_space(tx)?; + + tw.start_struct(TagType::Anonymous)?; + + if self.attr_requests.is_some() { + tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; + } + + Ok(true) + } + + pub fn complete( + self, tx: &mut Packet, transaction: &mut Transaction, + resume_path: Option, ) -> Result { - let reply = match self { - Self::Read(request) => { - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - if request.attr_requests.is_some() { - tw.end_container()?; - } + let mut tw = ReadReq::restore_long_read_space(tx)?; - // Suppress response always true for read interaction - tw.bool( - TagType::Context(msg::ReportDataTag::SupressResponse as u8), - true, - )?; + if self.attr_requests.is_some() { + tw.end_container()?; + } - tw.end_container()?; + if resume_path.is_some() { + tw.bool( + TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), + true, + )?; + } - transaction.complete(); + transaction + .exch_mut() + .set_suspended_subscribe_req(self.suspend(resume_path)); - true - } - Self::Write(request) => { - let suppress = request.supress_response.unwrap_or_default(); + tw.bool( + TagType::Context(msg::ReportDataTag::SupressResponse as u8), + false, + )?; - let mut tw = TLVWriter::new(tx.get_writebuf()?); + tw.end_container()?; - tw.end_container()?; - tw.end_container()?; + Ok(true) + } +} - transaction.complete(); +pub struct ResumeReadReq { + pub paths: heapless::Vec, + pub filters: heapless::Vec, + pub fabric_filtered: bool, + pub resume_path: GenericPath, +} - if suppress { - error!("Supress response is set, is this the expected handling?"); - false - } else { - true - } - } - Self::Invoke(request) => { - let mut tw = TLVWriter::new(tx.get_writebuf()?); +impl ResumeReadReq { + fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::ReportData as u8); - if request.inv_requests.is_some() { - tw.end_container()?; - } + let mut tw = ReadReq::reserve_long_read_space(tx)?; - tw.end_container()?; + tw.start_struct(TagType::Anonymous)?; - true - } - Self::Subscribe(request) => { - let mut tw = TLVWriter::new(tx.get_writebuf()?); + tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; - if request.attr_requests.is_some() { - tw.end_container()?; - } + Ok(true) + } - tw.end_container()?; + pub fn complete( + mut self, + tx: &mut Packet, + transaction: &mut Transaction, + resume_path: Option, + ) -> Result { + let mut tw = ReadReq::restore_long_read_space(tx)?; - true - } - Self::Status(_) => { - let mut tw = TLVWriter::new(tx.get_writebuf()?); + tw.end_container()?; - tw.end_container()?; + let continue_interaction = if let Some(resume_path) = resume_path { + tw.bool( + TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), + true, + )?; - true - } - Self::Timed(_) => false, + self.resume_path = resume_path; + transaction.exch_mut().set_suspended_read_req(self); + true + } else { + false }; - if reply { - info!("Sending response"); - print_tlv_list(tx.as_slice()); - } + tw.bool( + TagType::Context(msg::ReportDataTag::SupressResponse as u8), + !continue_interaction, + )?; - if transaction.is_terminate() { - transaction.ctx.exch.terminate(); - } else if transaction.is_complete() { - transaction.ctx.exch.close(); + tw.end_container()?; + + if !continue_interaction { + transaction.complete(); } Ok(true) } +} - fn create_status_response(tx: &mut Packet, status: IMStatusCode) -> Result<(), Error> { +pub struct ResumeSubscribeReq { + pub paths: heapless::Vec, + pub filters: heapless::Vec, + pub fabric_filtered: bool, + pub resume_path: Option, + pub keep_subs: bool, + pub min_int_floor: u16, + pub max_int_ceil: u16, +} + +impl ResumeSubscribeReq { + fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); - tx.set_proto_opcode(OpCode::StatusResponse as u8); - let mut tw = TLVWriter::new(tx.get_writebuf()?); + if self.resume_path.is_some() { + tx.set_proto_opcode(OpCode::ReportData as u8); - let status = StatusResp { status }; - status.to_tlv(&mut tw, TagType::Anonymous) + let mut tw = ReadReq::reserve_long_read_space(tx)?; + + tw.start_struct(TagType::Anonymous)?; + + tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; + } else { + tx.set_proto_opcode(OpCode::SubscribeResponse as u8); + + // let mut tw = TLVWriter::new(tx.get_writebuf()?); + // tw.start_struct(TagType::Anonymous)?; + } + + Ok(true) + } + + pub fn complete( + mut self, + tx: &mut Packet, + transaction: &mut Transaction, + resume_path: Option, + ) -> Result { + if self.resume_path.is_none() && resume_path.is_some() { + panic!("Cannot resume subscribe"); + } + + if self.resume_path.is_some() { + // Completing a ReportData message + let mut tw = ReadReq::restore_long_read_space(tx)?; + + tw.end_container()?; + + if resume_path.is_some() { + tw.bool( + TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), + true, + )?; + } + + tw.bool( + TagType::Context(msg::ReportDataTag::SupressResponse as u8), + false, + )?; + + tw.end_container()?; + + self.resume_path = resume_path; + transaction.exch_mut().set_suspended_subscribe_req(self); + } else { + // Completing a SubscribeResponse message + + // let mut tw = TLVWriter::new(tx.get_writebuf()?); + // tw.end_container()?; + + transaction.complete(); + } + + Ok(true) } } @@ -472,15 +729,14 @@ where T: DataHandler, { pub fn handle<'a>(&mut self, ctx: &'a mut ProtoCtx) -> Result, Error> { - let interaction = Interaction::new(ctx.rx)?; let mut transaction = Transaction::new(&mut ctx.exch_ctx); - let reply = if interaction.initiate_tx(ctx.tx, &mut transaction)? { - self.0.handle(&interaction, ctx.tx, &mut transaction)?; - interaction.complete_tx(ctx.tx, &mut transaction)? - } else { - true - }; + let reply = + if let Some(interaction) = Interaction::initiate(ctx.rx, ctx.tx, &mut transaction)? { + self.0.handle(interaction, ctx.tx, &mut transaction)? + } else { + true + }; Ok(reply.then_some(ctx.tx.as_slice())) } @@ -495,17 +751,14 @@ where &mut self, ctx: &'a mut ProtoCtx<'_, '_>, ) -> Result, Error> { - let interaction = Interaction::new(ctx.rx)?; let mut transaction = Transaction::new(&mut ctx.exch_ctx); - let reply = if interaction.initiate_tx(ctx.tx, &mut transaction)? { - self.0 - .handle(&interaction, ctx.tx, &mut transaction) - .await?; - interaction.complete_tx(ctx.tx, &mut transaction)? - } else { - true - }; + let reply = + if let Some(interaction) = Interaction::initiate(ctx.rx, ctx.tx, &mut transaction)? { + self.0.handle(interaction, ctx.tx, &mut transaction).await? + } else { + true + }; Ok(reply.then_some(ctx.tx.as_slice())) } diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index 80802ede..f5b9cb04 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -328,7 +328,7 @@ impl<'a> Case<'a> { tw.end_container()?; let key = KeyPair::new_from_public(initiator_noc_cert.get_pubkey())?; - key.verify_msg(write_buf.into_slice(), sign)?; + key.verify_msg(write_buf.as_slice(), sign)?; Ok(()) } @@ -508,7 +508,7 @@ impl<'a> Case<'a> { cipher_text, cipher_text.len() - TAG_LEN, )?; - Ok(write_buf.into_slice().len()) + Ok(write_buf.as_slice().len()) } fn get_sigma2_sign( @@ -531,7 +531,7 @@ impl<'a> Case<'a> { tw.str8(TagType::Context(4), peer_pub_key)?; tw.end_container()?; //println!("TBS is {:x?}", write_buf.as_borrow_slice()); - fabric.sign_msg(write_buf.into_slice(), signature) + fabric.sign_msg(write_buf.as_slice(), signature) } } diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index e668a8d5..c28a5b2e 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -22,6 +22,7 @@ use core::time::Duration; use log::{error, info, trace}; use crate::error::Error; +use crate::interaction_model::core::{ResumeReadReq, ResumeSubscribeReq}; use crate::secure_channel; use crate::secure_channel::case::CaseSession; use crate::utils::epoch::Epoch; @@ -68,6 +69,8 @@ enum State { pub enum DataOption { CaseSession(CaseSession), Time(Duration), + SuspendedReadReq(ResumeReadReq), + SuspendedSubscibeReq(ResumeSubscribeReq), #[default] None, } @@ -124,18 +127,14 @@ impl Exchange { self.role } - pub fn is_data_none(&self) -> bool { - matches!(self.data, DataOption::None) + pub fn clear_data(&mut self) { + self.data = DataOption::None; } pub fn set_case_session(&mut self, session: CaseSession) { self.data = DataOption::CaseSession(session); } - pub fn clear_data(&mut self) { - self.data = DataOption::None; - } - pub fn get_case_session(&mut self) -> Option<&mut CaseSession> { if let DataOption::CaseSession(session) = &mut self.data { Some(session) @@ -154,6 +153,34 @@ impl Exchange { } } + pub fn set_suspended_read_req(&mut self, req: ResumeReadReq) { + self.data = DataOption::SuspendedReadReq(req); + } + + pub fn take_suspended_read_req(&mut self) -> Option { + let old = core::mem::replace(&mut self.data, DataOption::None); + if let DataOption::SuspendedReadReq(req) = old { + Some(req) + } else { + self.data = old; + None + } + } + + pub fn set_suspended_subscribe_req(&mut self, req: ResumeSubscribeReq) { + self.data = DataOption::SuspendedSubscibeReq(req); + } + + pub fn take_suspended_subscribe_req(&mut self) -> Option { + let old = core::mem::replace(&mut self.data, DataOption::None); + if let DataOption::SuspendedSubscibeReq(req) = old { + Some(req) + } else { + self.data = old; + None + } + } + pub fn set_data_time(&mut self, expiry_ts: Option) { if let Some(t) = expiry_ts { self.data = DataOption::Time(t); @@ -430,7 +457,7 @@ mod tests { error::Error, transport::{ network::Address, - packet::Packet, + packet::{Packet, MAX_TX_BUF_SIZE}, session::{CloneData, SessionMode, MAX_SESSIONS}, }, utils::{ @@ -505,7 +532,7 @@ mod tests { /// - The sessions are evicted in LRU /// - The exchanges associated with those sessions are evicted too fn test_sess_evict() { - let mut mgr = ExchangeMgr::new(sys_epoch, dummy_rand); // TODO + let mut mgr = ExchangeMgr::new(sys_epoch, dummy_rand); fill_sessions(&mut mgr, MAX_SESSIONS + 1); // Sessions are now full from local session id 1 to 16 @@ -531,7 +558,7 @@ mod tests { let result = mgr.add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id)); assert!(matches!(result, Err(Error::NoSpace))); - let mut buf = [0; 1500]; + let mut buf = [0; MAX_TX_BUF_SIZE]; let tx = &mut Packet::new_tx(&mut buf); let evicted = mgr.evict_session(tx).unwrap(); assert!(evicted); diff --git a/matter/src/transport/packet.rs b/matter/src/transport/packet.rs index e39ac1c9..b2ca7aad 100644 --- a/matter/src/transport/packet.rs +++ b/matter/src/transport/packet.rs @@ -33,6 +33,7 @@ use super::{ }; pub const MAX_RX_BUF_SIZE: usize = 1583; +pub const MAX_TX_BUF_SIZE: usize = 1280 - 40/*IPV6 header size*/ - 8/*UDP header size*/; type Buffer = [u8; MAX_RX_BUF_SIZE]; // TODO: I am not very happy with this construction, need to find another way to do this diff --git a/matter/src/transport/proto_hdr.rs b/matter/src/transport/proto_hdr.rs index 96928ac2..fd392bd2 100644 --- a/matter/src/transport/proto_hdr.rs +++ b/matter/src/transport/proto_hdr.rs @@ -311,7 +311,7 @@ mod tests { encrypt_in_place(send_ctr, 0, &plain_hdr, &mut writebuf, &key).unwrap(); assert_eq!( - writebuf.into_slice(), + writebuf.as_slice(), [ 189, 83, 250, 121, 38, 87, 97, 17, 153, 78, 243, 20, 36, 11, 131, 142, 136, 165, 227, 107, 204, 129, 193, 153, 42, 131, 138, 254, 22, 190, 76, 244, 116, 45, 156, diff --git a/matter/src/transport/session.rs b/matter/src/transport/session.rs index d4c49852..95597e2f 100644 --- a/matter/src/transport/session.rs +++ b/matter/src/transport/session.rs @@ -267,7 +267,7 @@ impl Session { let mut tmp_buf = [0_u8; proto_hdr::max_proto_hdr_len()]; let mut write_buf = WriteBuf::new(&mut tmp_buf); tx.proto.encode(&mut write_buf)?; - tx.get_writebuf()?.prepend(write_buf.into_slice())?; + tx.get_writebuf()?.prepend(write_buf.as_slice())?; // Generate plain-text header if self.mode == SessionMode::PlainText { @@ -278,7 +278,7 @@ impl Session { let mut tmp_buf = [0_u8; plain_hdr::max_plain_hdr_len()]; let mut write_buf = WriteBuf::new(&mut tmp_buf); tx.plain.encode(&mut write_buf)?; - let plain_hdr_bytes = write_buf.into_slice(); + let plain_hdr_bytes = write_buf.as_slice(); trace!("unencrypted packet: {:x?}", tx.as_mut_slice()); let ctr = tx.plain.ctr; diff --git a/matter/src/utils/writebuf.rs b/matter/src/utils/writebuf.rs index fae44818..3adafe2f 100644 --- a/matter/src/utils/writebuf.rs +++ b/matter/src/utils/writebuf.rs @@ -18,44 +18,21 @@ use crate::error::*; use byteorder::{ByteOrder, LittleEndian}; -/// Shrink WriteBuf -/// -/// This Macro creates a new (child) WriteBuf which has a truncated slice end. -/// - It accepts a WriteBuf, and the size to reserve (truncate) towards the end. -/// - It returns the new (child) WriteBuf -#[macro_export] -macro_rules! wb_shrink { - ($orig_wb:ident, $reserve:ident) => {{ - let m_data = $orig_wb.empty_as_mut_slice(); - let m_wb = WriteBuf::new(m_data, m_data.len() - $reserve); - (m_wb) - }}; -} - -/// Unshrink WriteBuf -/// -/// This macro unshrinks the WriteBuf -/// - It accepts the original WriteBuf and the child WriteBuf (that was the result of wb_shrink) -/// After this call, the child WriteBuf shouldn't be used -#[macro_export] -macro_rules! wb_unshrink { - ($orig_wb:ident, $new_wb:ident) => {{ - let m_data_len = $new_wb.as_slice().len(); - $orig_wb.forward_tail_by(m_data_len); - }}; -} - #[derive(Debug)] pub struct WriteBuf<'a> { buf: &'a mut [u8], + buf_size: usize, start: usize, end: usize, } impl<'a> WriteBuf<'a> { pub fn new(buf: &'a mut [u8]) -> Self { + let buf_size = buf.len(); + Self { buf, + buf_size, start: 0, end: 0, } @@ -73,10 +50,6 @@ impl<'a> WriteBuf<'a> { self.end += new_offset } - pub fn into_slice(self) -> &'a [u8] { - &self.buf[self.start..self.end] - } - pub fn as_slice(&self) -> &[u8] { &self.buf[self.start..self.end] } @@ -86,20 +59,43 @@ impl<'a> WriteBuf<'a> { } pub fn empty_as_mut_slice(&mut self) -> &mut [u8] { - &mut self.buf[self.end..] + &mut self.buf[self.end..self.buf_size] } - pub fn reset(&mut self, reserve: usize) { - self.start = reserve; - self.end = reserve; + pub fn reset(&mut self) { + self.buf_size = self.buf.len(); + self.start = 0; + self.end = 0; } pub fn reserve(&mut self, reserve: usize) -> Result<(), Error> { - if self.end != 0 || self.start != 0 { - return Err(Error::Invalid); + if self.end != 0 || self.start != 0 || self.buf_size != self.buf.len() { + Err(Error::Invalid) + } else if reserve > self.buf_size { + Err(Error::NoSpace) + } else { + self.start = reserve; + self.end = reserve; + Ok(()) + } + } + + pub fn shrink(&mut self, with: usize) -> Result<(), Error> { + if self.end + with <= self.buf_size { + self.buf_size -= with; + Ok(()) + } else { + Err(Error::NoSpace) + } + } + + pub fn expand(&mut self, by: usize) -> Result<(), Error> { + if self.buf.len() - self.buf_size >= by { + self.buf_size += by; + Ok(()) + } else { + Err(Error::NoSpace) } - self.reset(reserve); - Ok(()) } pub fn prepend_with(&mut self, size: usize, f: F) -> Result<(), Error> @@ -125,7 +121,7 @@ impl<'a> WriteBuf<'a> { where F: FnOnce(&mut Self), { - if self.end + size <= self.buf.len() { + if self.end + size <= self.buf_size { f(self); self.end += size; return Ok(()); @@ -274,7 +270,7 @@ mod tests { buf.prepend(&new_slice).unwrap(); assert_eq!( - buf.into_slice(), + buf.as_slice(), [ 0xa, 0xb, 0xc, 1, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xbe, 0xba, 0xfe, 0xca, 0xbe, 0xba, 0xfe, 0xca diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index 348ce74a..116ad50c 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -24,16 +24,20 @@ use matter::{ cluster_on_off::{self, OnOffCluster}, core::DataModel, device_types::{DEV_TYPE_ON_OFF_LIGHT, DEV_TYPE_ROOT_NODE}, - objects::{ChainedHandler, Endpoint, Node, Privilege}, + objects::{Endpoint, Node, Privilege}, root_endpoint::{self, RootEndpointHandler}, sdm::{ admin_commissioning, dev_att::{DataType, DevAttDataFetcher}, general_commissioning, noc, nw_commissioning, }, - system_model::access_control, + system_model::{ + access_control, + descriptor::{self, DescriptorCluster}, + }, }, error::Error, + handler_chain_type, interaction_model::core::{InteractionModel, OpCode}, mdns::Mdns, tlv::{TLVWriter, TagType, ToTLV}, @@ -41,6 +45,7 @@ use matter::{ transport::{ exchange::{self, Exchange, ExchangeCtx}, network::Address, + packet::MAX_RX_BUF_SIZE, proto_ctx::ProtoCtx, session::{CaseDetails, CloneData, NocCatIds, SessionMgr, SessionMode}, }, @@ -97,12 +102,9 @@ impl<'a> ImInput<'a> { } } -pub type DmHandler<'a> = ChainedHandler< - OnOffCluster, - ChainedHandler>>, ->; +pub type DmHandler<'a> = handler_chain_type!(OnOffCluster, EchoCluster, DescriptorCluster, EchoCluster | RootEndpointHandler<'a>); -pub fn matter<'a>(mdns: &'a mut dyn Mdns) -> Matter<'_> { +pub fn matter(mdns: &mut dyn Mdns) -> Matter<'_> { Matter::new(&BASIC_INFO, mdns, sys_epoch, dummy_rand) } @@ -132,6 +134,7 @@ impl<'a> ImEngine<'a> { Endpoint { id: 0, clusters: &[ + descriptor::CLUSTER, cluster_basic_information::CLUSTER, general_commissioning::CLUSTER, nw_commissioning::CLUSTER, @@ -144,13 +147,18 @@ impl<'a> ImEngine<'a> { }, Endpoint { id: 1, - clusters: &[echo_cluster::CLUSTER, cluster_on_off::CLUSTER], + clusters: &[ + descriptor::CLUSTER, + cluster_on_off::CLUSTER, + echo_cluster::CLUSTER, + ], device_type: DEV_TYPE_ON_OFF_LIGHT, }, ], }, root_endpoint::handler(0, &DummyDevAtt {}, matter) .chain(0, echo_cluster::ID, EchoCluster::new(2, *matter.borrow())) + .chain(1, descriptor::ID, DescriptorCluster::new(*matter.borrow())) .chain(1, echo_cluster::ID, EchoCluster::new(3, *matter.borrow())) .chain(1, cluster_on_off::ID, OnOffCluster::new(*matter.borrow())), ); @@ -164,7 +172,7 @@ impl<'a> ImEngine<'a> { pub fn echo_cluster(&self, endpoint: u16) -> &EchoCluster { match endpoint { - 0 => &self.im.0.handler.next.next.handler, + 0 => &self.im.0.handler.next.next.next.handler, 1 => &self.im.0.handler.next.handler, _ => panic!(), } @@ -196,8 +204,8 @@ impl<'a> ImEngine<'a> { sess, epoch: *self.matter.borrow(), }; - let mut tx_buf = [0; 1500]; - let mut rx_buf = [0; 1500]; + let mut rx_buf = [0; MAX_RX_BUF_SIZE]; + let mut tx_buf = [0; 1450]; // For the long read tests to run unchanged let mut rx = Packet::new_rx(&mut rx_buf); let mut tx = Packet::new_tx(&mut tx_buf); // Create fake rx packet diff --git a/matter/tests/data_model/long_reads.rs b/matter/tests/data_model/long_reads.rs index 9f7957aa..693f1dfa 100644 --- a/matter/tests/data_model/long_reads.rs +++ b/matter/tests/data_model/long_reads.rs @@ -30,11 +30,13 @@ use matter::{ }, messages::{msg::SubscribeReq, GenericPath}, }, + mdns::DummyMdns, tlv::{self, ElementType, FromTLV, TLVElement, TagType, ToTLV}, transport::{ exchange::{self, Exchange}, udp::MAX_RX_BUF_SIZE, }, + Matter, }; use crate::{ @@ -42,28 +44,28 @@ use crate::{ common::{ attributes::*, echo_cluster as echo, - im_engine::{ImEngine, ImInput}, + im_engine::{matter, ImEngine, ImInput}, }, }; -pub struct LongRead { - im_engine: ImEngine, +pub struct LongRead<'a> { + im_engine: ImEngine<'a>, } -impl LongRead { - pub fn new() -> Self { - let mut im_engine = ImEngine::new(); +impl<'a> LongRead<'a> { + pub fn new(matter: &'a Matter<'a>) -> Self { + let mut im_engine = ImEngine::new(matter); // Use the same exchange for all parts of the transaction im_engine.exch = Some(Exchange::new(1, 0, exchange::Role::Responder)); Self { im_engine } } - pub fn process<'a>( + pub fn process<'p>( &mut self, action: OpCode, data: &dyn ToTLV, - data_out: &'a mut [u8], - ) -> (u8, &'a mut [u8]) { + data_out: &'p mut [u8], + ) -> (u8, &'p [u8]) { let input = ImInput::new(action, data); let (response, output) = self.im_engine.process(&input, data_out); (response, output) @@ -82,49 +84,139 @@ fn wildcard_read_resp(part: u8) -> Vec> { attr_data!(0, 29, descriptor::Attributes::ClientList, dont_care), attr_data!(0, 40, GlobalElements::FeatureMap, dont_care), attr_data!(0, 40, GlobalElements::AttributeList, dont_care), - attr_data!(0, 40, basic_info::Attributes::DMRevision, dont_care), - attr_data!(0, 40, basic_info::Attributes::VendorId, dont_care), - attr_data!(0, 40, basic_info::Attributes::ProductId, dont_care), - attr_data!(0, 40, basic_info::Attributes::HwVer, dont_care), - attr_data!(0, 40, basic_info::Attributes::SwVer, dont_care), - attr_data!(0, 40, basic_info::Attributes::SwVerString, dont_care), - attr_data!(0, 40, basic_info::Attributes::SerialNo, dont_care), + attr_data!( + 0, + 40, + basic_info::AttributesDiscriminants::DMRevision, + dont_care + ), + attr_data!( + 0, + 40, + basic_info::AttributesDiscriminants::VendorId, + dont_care + ), + attr_data!( + 0, + 40, + basic_info::AttributesDiscriminants::ProductId, + dont_care + ), + attr_data!(0, 40, basic_info::AttributesDiscriminants::HwVer, dont_care), + attr_data!(0, 40, basic_info::AttributesDiscriminants::SwVer, dont_care), + attr_data!( + 0, + 40, + basic_info::AttributesDiscriminants::SwVerString, + dont_care + ), + attr_data!( + 0, + 40, + basic_info::AttributesDiscriminants::SerialNo, + dont_care + ), attr_data!(0, 48, GlobalElements::FeatureMap, dont_care), attr_data!(0, 48, GlobalElements::AttributeList, dont_care), - attr_data!(0, 48, gen_comm::Attributes::BreadCrumb, dont_care), - attr_data!(0, 48, gen_comm::Attributes::RegConfig, dont_care), - attr_data!(0, 48, gen_comm::Attributes::LocationCapability, dont_care), attr_data!( 0, 48, - gen_comm::Attributes::BasicCommissioningInfo, + gen_comm::AttributesDiscriminants::BreadCrumb, + dont_care + ), + attr_data!( + 0, + 48, + gen_comm::AttributesDiscriminants::RegConfig, + dont_care + ), + attr_data!( + 0, + 48, + gen_comm::AttributesDiscriminants::LocationCapability, + dont_care + ), + attr_data!( + 0, + 48, + gen_comm::AttributesDiscriminants::BasicCommissioningInfo, dont_care ), attr_data!(0, 49, GlobalElements::FeatureMap, dont_care), attr_data!(0, 49, GlobalElements::AttributeList, dont_care), attr_data!(0, 60, GlobalElements::FeatureMap, dont_care), attr_data!(0, 60, GlobalElements::AttributeList, dont_care), - attr_data!(0, 60, adm_comm::Attributes::WindowStatus, dont_care), - attr_data!(0, 60, adm_comm::Attributes::AdminFabricIndex, dont_care), - attr_data!(0, 60, adm_comm::Attributes::AdminVendorId, dont_care), + attr_data!( + 0, + 60, + adm_comm::AttributesDiscriminants::WindowStatus, + dont_care + ), + attr_data!( + 0, + 60, + adm_comm::AttributesDiscriminants::AdminFabricIndex, + dont_care + ), + attr_data!( + 0, + 60, + adm_comm::AttributesDiscriminants::AdminVendorId, + dont_care + ), attr_data!(0, 62, GlobalElements::FeatureMap, dont_care), attr_data!(0, 62, GlobalElements::AttributeList, dont_care), - attr_data!(0, 62, noc::Attributes::CurrentFabricIndex, dont_care), - attr_data!(0, 62, noc::Attributes::Fabrics, dont_care), - attr_data!(0, 62, noc::Attributes::SupportedFabrics, dont_care), - attr_data!(0, 62, noc::Attributes::CommissionedFabrics, dont_care), + attr_data!( + 0, + 62, + noc::AttributesDiscriminants::CurrentFabricIndex, + dont_care + ), + attr_data!(0, 62, noc::AttributesDiscriminants::Fabrics, dont_care), + attr_data!( + 0, + 62, + noc::AttributesDiscriminants::SupportedFabrics, + dont_care + ), + attr_data!( + 0, + 62, + noc::AttributesDiscriminants::CommissionedFabrics, + dont_care + ), attr_data!(0, 31, GlobalElements::FeatureMap, dont_care), attr_data!(0, 31, GlobalElements::AttributeList, dont_care), - attr_data!(0, 31, acl::Attributes::Acl, dont_care), - attr_data!(0, 31, acl::Attributes::Extension, dont_care), - attr_data!(0, 31, acl::Attributes::SubjectsPerEntry, dont_care), - attr_data!(0, 31, acl::Attributes::TargetsPerEntry, dont_care), - attr_data!(0, 31, acl::Attributes::EntriesPerFabric, dont_care), + attr_data!(0, 31, acl::AttributesDiscriminants::Acl, dont_care), + attr_data!(0, 31, acl::AttributesDiscriminants::Extension, dont_care), + attr_data!( + 0, + 31, + acl::AttributesDiscriminants::SubjectsPerEntry, + dont_care + ), + attr_data!( + 0, + 31, + acl::AttributesDiscriminants::TargetsPerEntry, + dont_care + ), + attr_data!( + 0, + 31, + acl::AttributesDiscriminants::EntriesPerFabric, + dont_care + ), attr_data!(0, echo::ID, GlobalElements::FeatureMap, dont_care), attr_data!(0, echo::ID, GlobalElements::AttributeList, dont_care), - attr_data!(0, echo::ID, echo::Attributes::Att1, dont_care), - attr_data!(0, echo::ID, echo::Attributes::Att2, dont_care), - attr_data!(0, echo::ID, echo::Attributes::AttCustom, dont_care), + attr_data!(0, echo::ID, echo::AttributesDiscriminants::Att1, dont_care), + attr_data!(0, echo::ID, echo::AttributesDiscriminants::Att2, dont_care), + attr_data!( + 0, + echo::ID, + echo::AttributesDiscriminants::AttCustom, + dont_care + ), attr_data!(1, 29, GlobalElements::FeatureMap, dont_care), attr_data!(1, 29, GlobalElements::AttributeList, dont_care), attr_data!(1, 29, descriptor::Attributes::DeviceTypeList, dont_care), @@ -136,12 +228,17 @@ fn wildcard_read_resp(part: u8) -> Vec> { attr_data!(1, 29, descriptor::Attributes::ClientList, dont_care), attr_data!(1, 6, GlobalElements::FeatureMap, dont_care), attr_data!(1, 6, GlobalElements::AttributeList, dont_care), - attr_data!(1, 6, onoff::Attributes::OnOff, dont_care), + attr_data!(1, 6, onoff::AttributesDiscriminants::OnOff, dont_care), attr_data!(1, echo::ID, GlobalElements::FeatureMap, dont_care), attr_data!(1, echo::ID, GlobalElements::AttributeList, dont_care), - attr_data!(1, echo::ID, echo::Attributes::Att1, dont_care), - attr_data!(1, echo::ID, echo::Attributes::Att2, dont_care), - attr_data!(1, echo::ID, echo::Attributes::AttCustom, dont_care), + attr_data!(1, echo::ID, echo::AttributesDiscriminants::Att1, dont_care), + attr_data!(1, echo::ID, echo::AttributesDiscriminants::Att2, dont_care), + attr_data!( + 1, + echo::ID, + echo::AttributesDiscriminants::AttCustom, + dont_care + ), ]; if part == 1 { @@ -155,7 +252,9 @@ fn wildcard_read_resp(part: u8) -> Vec> { fn test_long_read_success() { // Read the entire attribute database, which requires 2 reads to complete let _ = env_logger::try_init(); - let mut lr = LongRead::new(); + let mut mdns = DummyMdns; + let matter = matter(&mut mdns); + let mut lr = LongRead::new(&matter); let mut output = [0_u8; MAX_RX_BUF_SIZE + 100]; let wc_path = GenericPath::new(None, None, None); @@ -187,7 +286,9 @@ fn test_long_read_success() { fn test_long_read_subscription_success() { // Subscribe to the entire attribute database, which requires 2 reads to complete let _ = env_logger::try_init(); - let mut lr = LongRead::new(); + let mut mdns = DummyMdns; + let matter = matter(&mut mdns); + let mut lr = LongRead::new(&matter); let mut output = [0_u8; MAX_RX_BUF_SIZE + 100]; let wc_path = GenericPath::new(None, None, None); @@ -219,6 +320,6 @@ fn test_long_read_subscription_success() { tlv::print_tlv_list(out_data); let root = tlv::get_root_node_struct(out_data).unwrap(); let subs_resp = SubscribeResp::from_tlv(&root).unwrap(); - assert_eq!(out_code, OpCode::SubscriptResponse as u8); + assert_eq!(out_code, OpCode::SubscribeResponse as u8); assert_eq!(subs_resp.subs_id, 1); } diff --git a/matter/tests/data_model_tests.rs b/matter/tests/data_model_tests.rs index 803c4c5e..392909fa 100644 --- a/matter/tests/data_model_tests.rs +++ b/matter/tests/data_model_tests.rs @@ -22,6 +22,6 @@ mod data_model { mod attribute_lists; mod attributes; mod commands; - // TODO mod long_reads; + mod long_reads; mod timed_requests; } diff --git a/matter/tests/interaction_model.rs b/matter/tests/interaction_model.rs index 07d114e9..b73ab46f 100644 --- a/matter/tests/interaction_model.rs +++ b/matter/tests/interaction_model.rs @@ -25,6 +25,8 @@ use matter::transport::exchange::Exchange; use matter::transport::exchange::ExchangeCtx; use matter::transport::network::Address; use matter::transport::packet::Packet; +use matter::transport::packet::MAX_RX_BUF_SIZE; +use matter::transport::packet::MAX_TX_BUF_SIZE; use matter::transport::proto_ctx::ProtoCtx; use matter::transport::session::SessionMgr; use matter::utils::epoch::dummy_epoch; @@ -52,30 +54,27 @@ impl DataModel { impl DataHandler for DataModel { fn handle( &mut self, - interaction: &Interaction, + interaction: Interaction, _tx: &mut Packet, _transaction: &mut Transaction, ) -> Result { - match interaction { - Interaction::Invoke(req) => { - if let Some(inv_requests) = &req.inv_requests { - for i in inv_requests.iter() { - let data = if let Some(data) = i.data.unwrap_tlv() { - data - } else { - continue; - }; - let cmd_path_ib = i.path; - let mut common_data = &mut self.node; - common_data.endpoint = cmd_path_ib.path.endpoint.unwrap_or(1); - common_data.cluster = cmd_path_ib.path.cluster.unwrap_or(0); - common_data.command = cmd_path_ib.path.leaf.unwrap_or(0) as u16; - data.confirm_struct().unwrap(); - common_data.variable = data.find_tag(0).unwrap().u8().unwrap(); - } + if let Interaction::Invoke(req) = interaction { + if let Some(inv_requests) = &req.inv_requests { + for i in inv_requests.iter() { + let data = if let Some(data) = i.data.unwrap_tlv() { + data + } else { + continue; + }; + let cmd_path_ib = i.path; + let mut common_data = &mut self.node; + common_data.endpoint = cmd_path_ib.path.endpoint.unwrap_or(1); + common_data.cluster = cmd_path_ib.path.cluster.unwrap_or(0); + common_data.command = cmd_path_ib.path.leaf.unwrap_or(0) as u16; + data.confirm_struct().unwrap(); + common_data.variable = data.find_tag(0).unwrap().u8().unwrap(); } } - _ => (), } Ok(false) @@ -109,8 +108,8 @@ fn handle_data(action: OpCode, data_in: &[u8], data_out: &mut [u8]) -> (DataMode sess, epoch: dummy_epoch, }; - let mut rx_buf = [0; 1500]; - let mut tx_buf = [0; 1500]; + let mut rx_buf = [0; MAX_RX_BUF_SIZE]; + let mut tx_buf = [0; MAX_TX_BUF_SIZE]; let mut rx = Packet::new_rx(&mut rx_buf); let mut tx = Packet::new_tx(&mut tx_buf); // Create fake rx packet diff --git a/matter_macro_derive/src/lib.rs b/matter_macro_derive/src/lib.rs index 0fc358f8..a1fc5532 100644 --- a/matter_macro_derive/src/lib.rs +++ b/matter_macro_derive/src/lib.rs @@ -138,11 +138,20 @@ fn gen_totlv_for_struct( let expanded = quote! { impl #generics ToTLV for #struct_name #generics { fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { - tw. #datatype (tag_type)?; - #( - self.#idents.to_tlv(tw, TagType::Context(#tags))?; - )* - tw.end_container() + let anchor = tw.get_tail(); + + if let Err(err) = (|| { + tw. #datatype (tag_type)?; + #( + self.#idents.to_tlv(tw, TagType::Context(#tags))?; + )* + tw.end_container() + })() { + tw.rewind_to(anchor); + Err(err) + } else { + Ok(()) + } } } }; @@ -179,17 +188,26 @@ fn gen_totlv_for_enum( } let expanded = quote! { - impl #generics ToTLV for #enum_name #generics { - fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { - tw.start_struct(tag_type)?; - match self { - #( - Self::#variant_names(c) => { c.to_tlv(tw, TagType::Context(#tags))?; }, - )* - } - tw.end_container() - } - } + impl #generics ToTLV for #enum_name #generics { + fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { + let anchor = tw.get_tail(); + + if let Err(err) = (|| { + tw.start_struct(tag_type)?; + match self { + #( + Self::#variant_names(c) => { c.to_tlv(tw, TagType::Context(#tags))?; }, + )* + } + tw.end_container() + })() { + tw.rewind_to(anchor); + Err(err) + } else { + Ok(()) + } + } + } }; // panic!("Expanded to {}", expanded); From 89aab6f444155e2964fbe20fdc398c0c0151222f Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sun, 23 Apr 2023 14:53:48 +0000 Subject: [PATCH 04/72] Remove allocations from Base38 and QR calc --- matter/Cargo.toml | 2 +- matter/src/codec/base38.rs | 207 ++++++++++++++++++----------------- matter/src/core.rs | 13 ++- matter/src/group_keys.rs | 30 ------ matter/src/pairing/mod.rs | 6 +- matter/src/pairing/qr.rs | 213 ++++++++++++++++++------------------- 6 files changed, 231 insertions(+), 240 deletions(-) diff --git a/matter/Cargo.toml b/matter/Cargo.toml index cf9497b4..0987e10d 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -15,7 +15,7 @@ name = "matter" path = "src/lib.rs" [features] -default = ["std", "crypto_mbedtls", "nightly"] +default = ["std", "crypto_mbedtls"] std = [] nightly = [] crypto_openssl = ["openssl", "foreign-types", "hmac", "sha2"] diff --git a/matter/src/codec/base38.rs b/matter/src/codec/base38.rs index 7b7e7587..14114e63 100644 --- a/matter/src/codec/base38.rs +++ b/matter/src/codec/base38.rs @@ -17,10 +17,6 @@ //! Base38 encoding and decoding functions. -extern crate alloc; - -use alloc::{string::String, vec::Vec}; - use crate::error::Error; const BASE38_CHARS: [char; 38] = [ @@ -81,60 +77,68 @@ const DECODE_BASE38: [u8; 46] = [ 35, // 'Z', =90 ]; -const BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK: [u8; 3] = [2, 4, 5]; const RADIX: u32 = BASE38_CHARS.len() as u32; /// Encode a byte array into a base38 string. /// /// # Arguments /// * `bytes` - byte array to encode -/// * `length` - optional length of the byte array to encode. If not specified, the entire byte array is encoded. -pub fn encode(bytes: &[u8], length: Option) -> String { - let mut offset = 0; - let mut result = String::new(); - - // if length is specified, use it, otherwise use the length of the byte array - // if length is specified but is greater than the length of the byte array, use the length of the byte array - let b_len = bytes.len(); - let length = length.map(|l| l.min(b_len)).unwrap_or(b_len); - - while offset < length { - let remaining = length - offset; - match remaining.cmp(&2) { - core::cmp::Ordering::Greater => { - result.push_str(&encode_base38( - ((bytes[offset + 2] as u32) << 16) - | ((bytes[offset + 1] as u32) << 8) - | (bytes[offset] as u32), - 5, - )); - offset += 3; - } - core::cmp::Ordering::Equal => { - result.push_str(&encode_base38( - ((bytes[offset + 1] as u32) << 8) | (bytes[offset] as u32), - 4, - )); - break; - } - core::cmp::Ordering::Less => { - result.push_str(&encode_base38(bytes[offset] as u32, 2)); - break; - } - } +pub fn encode_string(bytes: &[u8]) -> Result, Error> { + let mut string = heapless::String::new(); + for c in encode(bytes) { + string.push(c).map_err(|_| Error::NoSpace)?; } - result + Ok(string) } -fn encode_base38(mut value: u32, char_count: u8) -> String { - let mut result = String::new(); - for _ in 0..char_count { - let remainder = value % 38; - result.push(BASE38_CHARS[remainder as usize]); - value = (value - remainder) / 38; +pub fn encode(bytes: &[u8]) -> impl Iterator + '_ { + (0..bytes.len() / 3) + .flat_map(move |index| { + let offset = index * 3; + + encode_base38( + ((bytes[offset + 2] as u32) << 16) + | ((bytes[offset + 1] as u32) << 8) + | (bytes[offset] as u32), + 5, + ) + }) + .chain( + core::iter::once(bytes.len() % 3).flat_map(move |remainder| { + let offset = bytes.len() / 3 * 3; + + match remainder { + 2 => encode_base38( + ((bytes[offset + 1] as u32) << 8) | (bytes[offset] as u32), + 4, + ), + 1 => encode_base38(bytes[offset] as u32, 2), + _ => encode_base38(0, 0), + } + }), + ) +} + +fn encode_base38(mut value: u32, repeat: usize) -> impl Iterator { + (0..repeat).map(move |_| { + let remainder = value % RADIX; + let c = BASE38_CHARS[remainder as usize]; + + value = (value - remainder) / RADIX; + + c + }) +} + +pub fn decode_vec(base38_str: &str) -> Result, Error> { + let mut vec = heapless::Vec::new(); + + for byte in decode(base38_str) { + vec.push(byte?).map_err(|_| Error::NoSpace)?; } - result + + Ok(vec) } /// Decode a base38-encoded string into a byte slice @@ -142,57 +146,64 @@ fn encode_base38(mut value: u32, char_count: u8) -> String { /// # Arguments /// * `base38_str` - base38-encoded string to decode /// -/// Fails if the string contains invalid characters -pub fn decode(base38_str: &str) -> Result, Error> { - let mut result = Vec::new(); - let mut base38_characters_number: usize = base38_str.len(); - let mut decoded_base38_characters: usize = 0; - - while base38_characters_number > 0 { - let base38_characters_in_chunk: usize; - let bytes_in_decoded_chunk: usize; - - if base38_characters_number >= BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[2] as usize { - base38_characters_in_chunk = BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[2] as usize; - bytes_in_decoded_chunk = 3; - } else if base38_characters_number == BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[1] as usize { - base38_characters_in_chunk = BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[1] as usize; - bytes_in_decoded_chunk = 2; - } else if base38_characters_number == BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[0] as usize { - base38_characters_in_chunk = BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[0] as usize; - bytes_in_decoded_chunk = 1; - } else { - return Err(Error::InvalidData); - } - - let mut value = 0u32; - - for i in (1..=base38_characters_in_chunk).rev() { - let mut base38_chars = base38_str.chars(); - let v = decode_char(base38_chars.nth(decoded_base38_characters + i - 1).unwrap())?; +/// Fails if the string contains invalid characters or if the supplied buffer is too small to fit the decoded data +pub fn decode(base38_str: &str) -> impl Iterator> + '_ { + let stru = base38_str.as_bytes(); + + (0..stru.len() / 5) + .flat_map(move |index| { + let offset = index * 5; + decode_base38(&stru[offset..offset + 5]) + }) + .chain({ + let offset = stru.len() / 5 * 5; + decode_base38(&stru[offset..]) + }) + .take_while(Result::is_ok) +} - value = value * RADIX + v as u32; +fn decode_base38(chars: &[u8]) -> impl Iterator> { + let mut value = 0u32; + let mut cerr = None; + + let repeat = match chars.len() { + 5 => 3, + 4 => 2, + 2 => 1, + 0 => 0, + _ => -1, + }; + + if repeat >= 0 { + for c in chars.iter().rev() { + match decode_char(*c) { + Ok(v) => value = value * RADIX + v as u32, + Err(err) => { + cerr = Some(err); + break; + } + } } + } else { + cerr = Some(Error::InvalidData) + } - decoded_base38_characters += base38_characters_in_chunk; - base38_characters_number -= base38_characters_in_chunk; - - for _i in 0..bytes_in_decoded_chunk { - result.push(value as u8); - value >>= 8; - } + (0..repeat) + .map(move |_| { + if let Some(err) = cerr { + Err(err) + } else { + let byte = (value & 0xff) as u8; - if value > 0 { - // encoded value is too big to represent a correct chunk of size 1, 2 or 3 bytes - return Err(Error::InvalidArgument); - } - } + value >>= 8; - Ok(result) + Ok(byte) + } + }) + .take_while(Result::is_ok) } -fn decode_char(c: char) -> Result { - let c = c as u8; +fn decode_char(c: u8) -> Result { if !(45..=90).contains(&c) { return Err(Error::InvalidData); } @@ -215,15 +226,17 @@ mod tests { #[test] fn can_base38_encode() { - assert_eq!(encode(&DECODED, None), ENCODED); - assert_eq!(encode(&DECODED, Some(11)), ENCODED); - - // length is greater than the length of the byte array - assert_eq!(encode(&DECODED, Some(12)), ENCODED); + assert_eq!( + encode_string::<{ ENCODED.len() }>(&DECODED).unwrap(), + ENCODED + ); } #[test] fn can_base38_decode() { - assert_eq!(decode(ENCODED).expect("can not decode base38"), DECODED); + assert_eq!( + decode_vec::<{ DECODED.len() }>(ENCODED).expect("Cannot decode base38"), + DECODED + ); } } diff --git a/matter/src/core.rs b/matter/src/core.rs index 7b853b96..1c3eb25a 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -84,10 +84,19 @@ impl<'a> Matter<'a> { self.dev_det } - pub fn start(&mut self, dev_comm: CommissioningData) -> Result<(), Error> { + pub fn start( + &mut self, + dev_comm: CommissioningData, + buf: &mut [u8], + ) -> Result<(), Error> { let open_comm_window = self.fabric_mgr.borrow().is_empty(); if open_comm_window { - print_pairing_code_and_qr(self.dev_det, &dev_comm, DiscoveryCapabilities::default()); + print_pairing_code_and_qr::( + self.dev_det, + &dev_comm, + DiscoveryCapabilities::default(), + buf, + ); self.pase_mgr.borrow_mut().enable_pase_session( dev_comm.verifier, diff --git a/matter/src/group_keys.rs b/matter/src/group_keys.rs index c4dfaafa..1dc1c405 100644 --- a/matter/src/group_keys.rs +++ b/matter/src/group_keys.rs @@ -15,38 +15,8 @@ * limitations under the License. */ -use alloc::sync::Arc; -use std::sync::{Mutex, Once}; - use crate::{crypto, error::Error}; -extern crate alloc; - -// This is just makeshift implementation for now, not used anywhere -pub struct GroupKeys {} - -static mut G_GRP_KEYS: Option>> = None; -static INIT: Once = Once::new(); - -impl GroupKeys { - fn new() -> Self { - Self {} - } - - pub fn get() -> Result>, Error> { - unsafe { - INIT.call_once(|| { - G_GRP_KEYS = Some(Arc::new(Mutex::new(GroupKeys::new()))); - }); - Ok(G_GRP_KEYS.as_ref().ok_or(Error::Invalid)?.clone()) - } - } - - pub fn insert_key() -> Result<(), Error> { - Ok(()) - } -} - #[derive(Debug, Default)] pub struct KeySet { pub epoch_key: [u8; crypto::SYMM_KEY_LEN_BYTES], diff --git a/matter/src/pairing/mod.rs b/matter/src/pairing/mod.rs index d75e8f6d..ee5aaffa 100644 --- a/matter/src/pairing/mod.rs +++ b/matter/src/pairing/mod.rs @@ -82,14 +82,16 @@ impl DiscoveryCapabilities { } /// Prepares and prints the pairing code and the QR code for easy pairing. -pub fn print_pairing_code_and_qr( +pub fn print_pairing_code_and_qr( dev_det: &BasicInfoConfig, comm_data: &CommissioningData, discovery_capabilities: DiscoveryCapabilities, + buf: &mut [u8], ) { let pairing_code = compute_pairing_code(comm_data); let qr_code_data = QrSetupPayload::new(dev_det, comm_data, discovery_capabilities); - let data_str = payload_base38_representation(&qr_code_data).expect("Failed to encode"); + let data_str = + payload_base38_representation::(&qr_code_data, buf).expect("Failed to encode"); pretty_print_pairing_code(&pairing_code); print_qr_code(&data_str); diff --git a/matter/src/pairing/qr.rs b/matter/src/pairing/qr.rs index f1d844a5..70b668d7 100644 --- a/matter/src/pairing/qr.rs +++ b/matter/src/pairing/qr.rs @@ -15,8 +15,6 @@ * limitations under the License. */ -use heapless::FnvIndexMap; - use crate::{ tlv::{TLVWriter, TagType}, utils::writebuf::WriteBuf, @@ -45,6 +43,7 @@ const TOTAL_PAYLOAD_DATA_SIZE_IN_BITS: usize = VERSION_FIELD_LENGTH_IN_BITS + PAYLOAD_DISCRIMINATOR_FIELD_LENGTH_IN_BITS + SETUP_PINCODE_FIELD_LENGTH_IN_BITS + PADDING_FIELD_LENGTH_IN_BITS; + const TOTAL_PAYLOAD_DATA_SIZE_IN_BYTES: usize = TOTAL_PAYLOAD_DATA_SIZE_IN_BITS / 8; // Spec 5.1.4.2 CHIP-Common Reserved Tags @@ -80,8 +79,8 @@ pub struct QrSetupPayload<'data> { discovery_capabilities: DiscoveryCapabilities, dev_det: &'data BasicInfoConfig<'data>, comm_data: &'data CommissioningData, - // we use a BTreeMap to keep the order of the optional data stable - optional_data: heapless::FnvIndexMap, + // The vec is ordered by the tag of OptionalQRCodeInfo + optional_data: heapless::Vec, } impl<'data> QrSetupPayload<'data> { @@ -98,7 +97,7 @@ impl<'data> QrSetupPayload<'data> { discovery_capabilities, dev_det, comm_data, - optional_data: FnvIndexMap::new(), + optional_data: heapless::Vec::new(), }; if !dev_det.serial_no.is_empty() { @@ -132,15 +131,11 @@ impl<'data> QrSetupPayload<'data> { /// * `tag` - tag number in the [0x80-0xFF] range /// * `data` - Data to add pub fn add_optional_vendor_data(&mut self, tag: u8, data: QRCodeInfoType) -> Result<(), Error> { - if !is_vendor_tag(tag) { - return Err(Error::InvalidArgument); + if is_vendor_tag(tag) { + self.add_optional_data(tag, data) + } else { + Err(Error::InvalidArgument) } - - self.optional_data - .insert(tag, OptionalQRCodeInfo { tag, data }) - .map_err(|_| Error::NoSpace)?; - - Ok(()) } /// A function to add an optional QR Code info CHIP object @@ -152,18 +147,26 @@ impl<'data> QrSetupPayload<'data> { tag: u8, data: QRCodeInfoType, ) -> Result<(), Error> { - if !is_common_tag(tag) { - return Err(Error::InvalidArgument); + if is_common_tag(tag) { + self.add_optional_data(tag, data) + } else { + Err(Error::InvalidArgument) } + } - self.optional_data - .insert(tag, OptionalQRCodeInfo { tag, data }) - .map_err(|_| Error::NoSpace)?; + fn add_optional_data(&mut self, tag: u8, data: QRCodeInfoType) -> Result<(), Error> { + let item = OptionalQRCodeInfo { tag, data }; + let index = self.optional_data.iter().position(|info| tag < info.tag); - Ok(()) + if let Some(index) = index { + self.optional_data.insert(index, item) + } else { + self.optional_data.push(item) + } + .map_err(|_| Error::NoSpace) } - pub fn get_all_optional_data(&self) -> &FnvIndexMap { + pub fn get_all_optional_data(&self) -> &[OptionalQRCodeInfo] { &self.optional_data } @@ -249,35 +252,26 @@ pub enum CommissionningFlowType { Custom = 2, } -struct TlvData { - max_data_length_in_bytes: u32, - data_length_in_bytes: Option, - data: Option>, -} +pub(super) fn payload_base38_representation( + payload: &QrSetupPayload, + buf: &mut [u8], +) -> Result, Error> { + if payload.is_valid() { + let (bits_buf, tlv_buf) = if payload.has_tlv() { + let (bits_buf, tlv_buf) = buf.split_at_mut(buf.len() / 2); -pub(super) fn payload_base38_representation(payload: &QrSetupPayload) -> Result { - let (mut bits, tlv_data) = if payload.has_tlv() { - let buffer_size = estimate_buffer_size(payload)?; - ( - vec![0; buffer_size], - Some(TlvData { - max_data_length_in_bytes: buffer_size as u32, - data_length_in_bytes: None, - data: None, - }), - ) - } else { - (vec![0; TOTAL_PAYLOAD_DATA_SIZE_IN_BYTES], None) - }; + (bits_buf, Some(tlv_buf)) + } else { + (buf, None) + }; - if !payload.is_valid() { - return Err(Error::InvalidArgument); + payload_base38_representation_with_tlv(payload, bits_buf, tlv_buf) + } else { + Err(Error::InvalidArgument) } - - payload_base38_representation_with_tlv(payload, &mut bits, tlv_data) } -fn estimate_buffer_size(payload: &QrSetupPayload) -> Result { +pub fn estimate_buffer_size(payload: &QrSetupPayload) -> Result { // Estimate the size of the needed buffer; initialize with the size of the standard payload. let mut estimate = TOTAL_PAYLOAD_DATA_SIZE_IN_BYTES; @@ -298,10 +292,9 @@ fn estimate_buffer_size(payload: &QrSetupPayload) -> Result { size }; - let vendor_data = payload.get_all_optional_data(); - vendor_data.values().for_each(|data| { + for data in payload.get_all_optional_data() { estimate += data_item_size_estimate(&data.data); - }); + } estimate = estimate_struct_overhead(estimate); @@ -372,70 +365,72 @@ fn populate_bits( Ok(()) } -fn payload_base38_representation_with_tlv( +fn payload_base38_representation_with_tlv( payload: &QrSetupPayload, - bits: &mut [u8], - mut tlv_data: Option, -) -> Result { - if let Some(tlv_data) = tlv_data.as_mut() { - generate_tlv_from_optional_data(payload, tlv_data)?; + bits_buf: &mut [u8], + tlv_buf: Option<&mut [u8]>, +) -> Result, Error> { + let tlv_data = if let Some(tlv_buf) = tlv_buf { + Some(generate_tlv_from_optional_data(payload, tlv_buf)?) + } else { + None + }; + + let bits = generate_bit_set(payload, bits_buf, tlv_data)?; + + let mut base38_encoded: heapless::String = "MT:".into(); + + for c in base38::encode(bits) { + base38_encoded.push(c).map_err(|_| Error::NoSpace)?; } - let bytes_written = generate_bit_set(payload, bits, tlv_data)?; - let base38_encoded = base38::encode(&*bits, Some(bytes_written)); - Ok(format!("MT:{}", base38_encoded)) + Ok(base38_encoded) } -fn generate_tlv_from_optional_data( +fn generate_tlv_from_optional_data<'a>( payload: &QrSetupPayload, - tlv_data: &mut TlvData, -) -> Result<(), Error> { - let size_needed = tlv_data.max_data_length_in_bytes as usize; - let mut tlv_buffer = vec![0u8; size_needed]; - let mut wb = WriteBuf::new(&mut tlv_buffer); + tlv_buf: &'a mut [u8], +) -> Result<&'a [u8], Error> { + let mut wb = WriteBuf::new(tlv_buf); let mut tw = TLVWriter::new(&mut wb); tw.start_struct(TagType::Anonymous)?; - let data = payload.get_all_optional_data(); - - for (tag, value) in data { - match &value.data { - QRCodeInfoType::String(data) => tw.utf8(TagType::Context(*tag), data.as_bytes())?, - QRCodeInfoType::Int32(data) => tw.i32(TagType::Context(*tag), *data)?, - QRCodeInfoType::Int64(data) => tw.i64(TagType::Context(*tag), *data)?, - QRCodeInfoType::UInt32(data) => tw.u32(TagType::Context(*tag), *data)?, - QRCodeInfoType::UInt64(data) => tw.u64(TagType::Context(*tag), *data)?, + + for info in payload.get_all_optional_data() { + match &info.data { + QRCodeInfoType::String(data) => tw.utf8(TagType::Context(info.tag), data.as_bytes())?, + QRCodeInfoType::Int32(data) => tw.i32(TagType::Context(info.tag), *data)?, + QRCodeInfoType::Int64(data) => tw.i64(TagType::Context(info.tag), *data)?, + QRCodeInfoType::UInt32(data) => tw.u32(TagType::Context(info.tag), *data)?, + QRCodeInfoType::UInt64(data) => tw.u64(TagType::Context(info.tag), *data)?, } } tw.end_container()?; - tlv_data.data_length_in_bytes = Some(tw.get_tail()); - tlv_data.data = Some(tlv_buffer); - Ok(()) + let tail = tw.get_tail(); + + Ok(&tlv_buf[..tail]) } -fn generate_bit_set( +fn generate_bit_set<'a>( payload: &QrSetupPayload, - bits: &mut [u8], - tlv_data: Option, -) -> Result { - let mut offset: usize = 0; + bits_buf: &'a mut [u8], + tlv_data: Option<&[u8]>, +) -> Result<&'a [u8], Error> { + let total_payload_size_in_bits = + TOTAL_PAYLOAD_DATA_SIZE_IN_BITS + tlv_data.map(|tlv_data| tlv_data.len() * 8).unwrap_or(0); - let total_payload_size_in_bits = if let Some(tlv_data) = &tlv_data { - TOTAL_PAYLOAD_DATA_SIZE_IN_BITS + (tlv_data.data_length_in_bytes.unwrap_or_default() * 8) - } else { - TOTAL_PAYLOAD_DATA_SIZE_IN_BITS - }; - - if bits.len() * 8 < total_payload_size_in_bits { + if bits_buf.len() * 8 < total_payload_size_in_bits { return Err(Error::BufferTooSmall); }; let passwd = passwd_from_comm_data(payload.comm_data); + let mut offset: usize = 0; + populate_bits( - bits, + bits_buf, &mut offset, payload.version as u64, VERSION_FIELD_LENGTH_IN_BITS, @@ -443,7 +438,7 @@ fn generate_bit_set( )?; populate_bits( - bits, + bits_buf, &mut offset, payload.dev_det.vid as u64, VENDOR_IDFIELD_LENGTH_IN_BITS, @@ -451,7 +446,7 @@ fn generate_bit_set( )?; populate_bits( - bits, + bits_buf, &mut offset, payload.dev_det.pid as u64, PRODUCT_IDFIELD_LENGTH_IN_BITS, @@ -459,7 +454,7 @@ fn generate_bit_set( )?; populate_bits( - bits, + bits_buf, &mut offset, payload.flow_type as u64, COMMISSIONING_FLOW_FIELD_LENGTH_IN_BITS, @@ -467,7 +462,7 @@ fn generate_bit_set( )?; populate_bits( - bits, + bits_buf, &mut offset, payload.discovery_capabilities.as_bits() as u64, RENDEZVOUS_INFO_FIELD_LENGTH_IN_BITS, @@ -475,7 +470,7 @@ fn generate_bit_set( )?; populate_bits( - bits, + bits_buf, &mut offset, payload.comm_data.discriminator as u64, PAYLOAD_DISCRIMINATOR_FIELD_LENGTH_IN_BITS, @@ -483,7 +478,7 @@ fn generate_bit_set( )?; populate_bits( - bits, + bits_buf, &mut offset, passwd as u64, SETUP_PINCODE_FIELD_LENGTH_IN_BITS, @@ -491,7 +486,7 @@ fn generate_bit_set( )?; populate_bits( - bits, + bits_buf, &mut offset, 0, PADDING_FIELD_LENGTH_IN_BITS, @@ -499,26 +494,22 @@ fn generate_bit_set( )?; if let Some(tlv_data) = tlv_data { - populate_tlv_bits(bits, &mut offset, tlv_data, total_payload_size_in_bits)?; + populate_tlv_bits(bits_buf, &mut offset, tlv_data, total_payload_size_in_bits)?; } let bytes_written = (offset + 7) / 8; - Ok(bytes_written) + + Ok(&bits_buf[..bytes_written]) } fn populate_tlv_bits( - bits: &mut [u8], + bits_buf: &mut [u8], offset: &mut usize, - tlv_data: TlvData, + tlv_data: &[u8], total_payload_size_in_bits: usize, ) -> Result<(), Error> { - if let (Some(data), Some(data_length_in_bytes)) = (tlv_data.data, tlv_data.data_length_in_bytes) - { - for b in data.iter().take(data_length_in_bytes) { - populate_bits(bits, offset, *b as u64, 8, total_payload_size_in_bits)?; - } - } else { - return Err(Error::InvalidArgument); + for b in tlv_data { + populate_bits(bits_buf, offset, *b as u64, 8, total_payload_size_in_bits)?; } Ok(()) @@ -555,7 +546,9 @@ mod tests { let disc_cap = DiscoveryCapabilities::new(false, true, false); let qr_code_data = QrSetupPayload::new(&dev_det, &comm_data, disc_cap); - let data_str = payload_base38_representation(&qr_code_data).expect("Failed to encode"); + let mut buf = [0; 1024]; + let data_str = payload_base38_representation::<128>(&qr_code_data, &mut buf) + .expect("Failed to encode"); assert_eq!(data_str, QR_CODE) } @@ -576,7 +569,9 @@ mod tests { let disc_cap = DiscoveryCapabilities::new(true, false, false); let qr_code_data = QrSetupPayload::new(&dev_det, &comm_data, disc_cap); - let data_str = payload_base38_representation(&qr_code_data).expect("Failed to encode"); + let mut buf = [0; 1024]; + let data_str = payload_base38_representation::<128>(&qr_code_data, &mut buf) + .expect("Failed to encode"); assert_eq!(data_str, QR_CODE) } @@ -620,7 +615,9 @@ mod tests { ) .expect("Failed to add optional data"); - let data_str = payload_base38_representation(&qr_code_data).expect("Failed to encode"); + let mut buf = [0; 1024]; + let data_str = payload_base38_representation::<{ QR_CODE.len() }>(&qr_code_data, &mut buf) + .expect("Failed to encode"); assert_eq!(data_str, QR_CODE) } } From bcbac965cddac60fa895afad651093e659f207f4 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 24 Apr 2023 06:10:58 +0000 Subject: [PATCH 05/72] Remove allocations from Cert handling --- matter/src/cert/asn1_writer.rs | 2 +- matter/src/cert/mod.rs | 148 ++++++++++++++++-------------- matter/src/data_model/sdm/noc.rs | 36 +++++--- matter/src/error.rs | 9 +- matter/src/fabric.rs | 119 +++++++++++++----------- matter/src/secure_channel/case.rs | 16 ++-- matter/src/tlv/traits.rs | 117 ++++++----------------- 7 files changed, 205 insertions(+), 242 deletions(-) diff --git a/matter/src/cert/asn1_writer.rs b/matter/src/cert/asn1_writer.rs index ae2ced83..675546a0 100644 --- a/matter/src/cert/asn1_writer.rs +++ b/matter/src/cert/asn1_writer.rs @@ -17,7 +17,7 @@ use super::{CertConsumer, MAX_DEPTH}; use crate::error::Error; -use chrono::{Datelike, TimeZone, Utc}; +use chrono::{Datelike, TimeZone, Utc}; // TODO use core::fmt::Write; use log::warn; diff --git a/matter/src/cert/mod.rs b/matter/src/cert/mod.rs index b9283299..757a9d66 100644 --- a/matter/src/cert/mod.rs +++ b/matter/src/cert/mod.rs @@ -15,23 +15,22 @@ * limitations under the License. */ -use core::fmt; - -extern crate alloc; +use core::fmt::{self, Write}; use crate::{ crypto::KeyPair, error::Error, - tlv::{self, FromTLV, TLVArrayOwned, TLVElement, TLVWriter, TagType, ToTLV}, + tlv::{self, FromTLV, OctetStr, TLVArray, TLVElement, TLVWriter, TagType, ToTLV}, utils::writebuf::WriteBuf, }; -use alloc::{format, string::String, vec::Vec}; use log::error; use num_derive::FromPrimitive; pub use self::asn1_writer::ASN1Writer; use self::printer::CertPrinter; +pub const MAX_CERT_TLV_LEN: usize = 300; // TODO + // As per https://datatracker.ietf.org/doc/html/rfc5280 const OID_PUB_KEY_ECPUBKEY: [u8; 7] = [0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01]; @@ -116,8 +115,10 @@ macro_rules! add_if { }; } -fn get_print_str(key_usage: u16) -> String { - format!( +fn get_print_str(key_usage: u16) -> heapless::String<256> { + let mut string = heapless::String::new(); + write!( + &mut string, "{}{}{}{}{}{}{}{}{}", add_if!(key_usage, KEY_USAGE_DIGITAL_SIGN, "digitalSignature "), add_if!(key_usage, KEY_USAGE_NON_REPUDIATION, "nonRepudiation "), @@ -129,6 +130,9 @@ fn get_print_str(key_usage: u16) -> String { add_if!(key_usage, KEY_USAGE_ENCIPHER_ONLY, "encipherOnly "), add_if!(key_usage, KEY_USAGE_DECIPHER_ONLY, "decipherOnly "), ) + .unwrap(); + + string } #[allow(unused_assignments)] @@ -140,7 +144,7 @@ fn encode_key_usage(key_usage: u16, w: &mut dyn CertConsumer) -> Result<(), Erro } fn encode_extended_key_usage( - list: &TLVArrayOwned, + list: impl Iterator, w: &mut dyn CertConsumer, ) -> Result<(), Error> { const OID_SERVER_AUTH: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01]; @@ -160,19 +164,18 @@ fn encode_extended_key_usage( ]; w.start_seq("")?; - for t in list.iter() { - let t = *t as usize; + for t in list { + let t = t as usize; if t > 0 && t <= encoding.len() { w.oid(encoding[t].0, encoding[t].1)?; } else { error!("Skipping encoding key usage out of bounds"); } } - w.end_seq()?; - Ok(()) + w.end_seq() } -#[derive(FromTLV, ToTLV, Default)] +#[derive(FromTLV, ToTLV, Default, Debug)] #[tlvargs(start = 1)] struct BasicConstraints { is_ca: bool, @@ -212,18 +215,18 @@ fn encode_extension_end(w: &mut dyn CertConsumer) -> Result<(), Error> { w.end_seq() } -#[derive(FromTLV, ToTLV, Default)] -#[tlvargs(start = 1, datatype = "list")] -struct Extensions { +#[derive(FromTLV, ToTLV, Default, Debug)] +#[tlvargs(lifetime = "'a", start = 1, datatype = "list")] +struct Extensions<'a> { basic_const: Option, key_usage: Option, - ext_key_usage: Option>, - subj_key_id: Option>, - auth_key_id: Option>, - future_extensions: Option>, + ext_key_usage: Option>, + subj_key_id: Option>, + auth_key_id: Option>, + future_extensions: Option>, } -impl Extensions { +impl<'a> Extensions<'a> { fn encode(&self, w: &mut dyn CertConsumer) -> Result<(), Error> { const OID_BASIC_CONSTRAINTS: [u8; 3] = [0x55, 0x1D, 0x13]; const OID_KEY_USAGE: [u8; 3] = [0x55, 0x1D, 0x0F]; @@ -245,30 +248,29 @@ impl Extensions { } if let Some(t) = &self.ext_key_usage { encode_extension_start("X509v3 Extended Key Usage", true, &OID_EXT_KEY_USAGE, w)?; - encode_extended_key_usage(t, w)?; + encode_extended_key_usage(t.iter(), w)?; encode_extension_end(w)?; } if let Some(t) = &self.subj_key_id { encode_extension_start("Subject Key ID", false, &OID_SUBJ_KEY_IDENTIFIER, w)?; - w.ostr("", t.as_slice())?; + w.ostr("", t.0)?; encode_extension_end(w)?; } if let Some(t) = &self.auth_key_id { encode_extension_start("Auth Key ID", false, &OID_AUTH_KEY_ID, w)?; w.start_seq("")?; - w.ctx("", 0, t.as_slice())?; + w.ctx("", 0, t.0)?; w.end_seq()?; encode_extension_end(w)?; } if let Some(t) = &self.future_extensions { - error!("Future Extensions Not Yet Supported: {:x?}", t.as_slice()) + error!("Future Extensions Not Yet Supported: {:x?}", t.0); } w.end_seq()?; w.end_ctx()?; Ok(()) } } -const MAX_DN_ENTRIES: usize = 5; #[derive(FromPrimitive, Copy, Clone)] enum DnTags { @@ -296,20 +298,23 @@ enum DnTags { NocCat = 22, } -enum DistNameValue { +#[derive(Debug)] +enum DistNameValue<'a> { Uint(u64), - Utf8Str(Vec), - PrintableStr(Vec), + Utf8Str(&'a [u8]), + PrintableStr(&'a [u8]), } -#[derive(Default)] -struct DistNames { +const MAX_DN_ENTRIES: usize = 5; + +#[derive(Default, Debug)] +struct DistNames<'a> { // The order in which the DNs arrive is important, as the signing // requires that the ASN1 notation retains the same order - dn: Vec<(u8, DistNameValue)>, + dn: heapless::Vec<(u8, DistNameValue<'a>), MAX_DN_ENTRIES>, } -impl DistNames { +impl<'a> DistNames<'a> { fn u64(&self, match_id: DnTags) -> Option { self.dn .iter() @@ -339,24 +344,27 @@ impl DistNames { const PRINTABLE_STR_THRESHOLD: u8 = 0x80; -impl<'a> FromTLV<'a> for DistNames { +impl<'a> FromTLV<'a> for DistNames<'a> { fn from_tlv(t: &TLVElement<'a>) -> Result { let mut d = Self { - dn: Vec::with_capacity(MAX_DN_ENTRIES), + dn: heapless::Vec::new(), }; let iter = t.confirm_list()?.enter().ok_or(Error::Invalid)?; for t in iter { if let TagType::Context(tag) = t.get_tag() { if let Ok(value) = t.u64() { - d.dn.push((tag, DistNameValue::Uint(value))); + d.dn.push((tag, DistNameValue::Uint(value))) + .map_err(|_| Error::BufferTooSmall)?; } else if let Ok(value) = t.slice() { if tag > PRINTABLE_STR_THRESHOLD { d.dn.push(( tag - PRINTABLE_STR_THRESHOLD, - DistNameValue::PrintableStr(value.to_vec()), - )); + DistNameValue::PrintableStr(value), + )) + .map_err(|_| Error::BufferTooSmall)?; } else { - d.dn.push((tag, DistNameValue::Utf8Str(value.to_vec()))); + d.dn.push((tag, DistNameValue::Utf8Str(value))) + .map_err(|_| Error::BufferTooSmall)?; } } } @@ -365,24 +373,23 @@ impl<'a> FromTLV<'a> for DistNames { } } -impl ToTLV for DistNames { +impl<'a> ToTLV for DistNames<'a> { fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { tw.start_list(tag)?; for (name, value) in &self.dn { match value { DistNameValue::Uint(v) => tw.u64(TagType::Context(*name), *v)?, - DistNameValue::Utf8Str(v) => tw.utf8(TagType::Context(*name), v.as_slice())?, - DistNameValue::PrintableStr(v) => tw.utf8( - TagType::Context(*name + PRINTABLE_STR_THRESHOLD), - v.as_slice(), - )?, + DistNameValue::Utf8Str(v) => tw.utf8(TagType::Context(*name), v)?, + DistNameValue::PrintableStr(v) => { + tw.utf8(TagType::Context(*name + PRINTABLE_STR_THRESHOLD), v)? + } } } tw.end_container() } } -impl DistNames { +impl<'a> DistNames<'a> { fn encode(&self, tag: &str, w: &mut dyn CertConsumer) -> Result<(), Error> { const OID_COMMON_NAME: [u8; 3] = [0x55_u8, 0x04, 0x03]; const OID_SURNAME: [u8; 3] = [0x55_u8, 0x04, 0x04]; @@ -520,38 +527,36 @@ fn encode_dn_value( } }, DistNameValue::Utf8Str(v) => { - let str = String::from_utf8(v.to_vec())?; - w.utf8str("", &str)?; + w.utf8str("", core::str::from_utf8(v)?)?; } DistNameValue::PrintableStr(v) => { - let str = String::from_utf8(v.to_vec())?; - w.printstr("", &str)?; + w.printstr("", core::str::from_utf8(v)?)?; } } w.end_seq()?; w.end_set() } -#[derive(FromTLV, ToTLV, Default)] -#[tlvargs(start = 1)] -pub struct Cert { - serial_no: Vec, +#[derive(FromTLV, ToTLV, Default, Debug)] +#[tlvargs(lifetime = "'a", start = 1)] +pub struct Cert<'a> { + serial_no: OctetStr<'a>, sign_algo: u8, - issuer: DistNames, + issuer: DistNames<'a>, not_before: u32, not_after: u32, - subject: DistNames, + subject: DistNames<'a>, pubkey_algo: u8, ec_curve_id: u8, - pubkey: Vec, - extensions: Extensions, - signature: Vec, + pubkey: OctetStr<'a>, + extensions: Extensions<'a>, + signature: OctetStr<'a>, } // TODO: Instead of parsing the TLVs everytime, we should just cache this, but the encoding // rules in terms of sequence may get complicated. Need to look into this -impl Cert { - pub fn new(cert_bin: &[u8]) -> Result { +impl<'a> Cert<'a> { + pub fn new(cert_bin: &'a [u8]) -> Result { let root = tlv::get_root_node(cert_bin)?; Cert::from_tlv(&root) } @@ -569,17 +574,21 @@ impl Cert { } pub fn get_pubkey(&self) -> &[u8] { - self.pubkey.as_slice() + self.pubkey.0 } pub fn get_subject_key_id(&self) -> Result<&[u8], Error> { - self.extensions.subj_key_id.as_deref().ok_or(Error::Invalid) + if let Some(id) = self.extensions.subj_key_id.as_ref() { + Ok(id.0) + } else { + Err(Error::Invalid) + } } pub fn is_authority(&self, their: &Cert) -> Result { if let Some(our_auth_key) = &self.extensions.auth_key_id { let their_subject = their.get_subject_key_id()?; - if our_auth_key == their_subject { + if our_auth_key.0 == their_subject { Ok(true) } else { Ok(false) @@ -590,7 +599,7 @@ impl Cert { } pub fn get_signature(&self) -> &[u8] { - self.signature.as_slice() + self.signature.0 } pub fn as_tlv(&self, buf: &mut [u8]) -> Result { @@ -617,7 +626,7 @@ impl Cert { w.integer("", &[2])?; w.end_ctx()?; - w.integer("Serial Num:", self.serial_no.as_slice())?; + w.integer("Serial Num:", self.serial_no.0)?; w.start_seq("Signature Algorithm:")?; let (str, oid) = match get_sign_algo(self.sign_algo).ok_or(Error::Invalid)? { @@ -647,7 +656,7 @@ impl Cert { w.oid(str, &curve_id)?; w.end_seq()?; - w.bitstr("Public-Key:", false, self.pubkey.as_slice())?; + w.bitstr("Public-Key:", false, self.pubkey.0)?; w.end_seq()?; self.extensions.encode(w)?; @@ -658,7 +667,7 @@ impl Cert { } } -impl fmt::Display for Cert { +impl<'a> fmt::Display for Cert<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut printer = CertPrinter::new(f); let _ = self @@ -670,7 +679,7 @@ impl fmt::Display for Cert { } pub struct CertVerifier<'a> { - cert: &'a Cert, + cert: &'a Cert<'a>, } impl<'a> CertVerifier<'a> { @@ -809,6 +818,7 @@ mod tests { #[test] fn test_tlv_conversions() { + let _ = env_logger::try_init(); let test_input: [&[u8]; 3] = [ &test_vectors::NOC1_SUCCESS, &test_vectors::ICAC1_SUCCESS, diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index acaea504..b2dcee21 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -19,7 +19,7 @@ use core::cell::RefCell; use core::convert::TryInto; use crate::acl::{AclEntry, AclMgr, AuthMode}; -use crate::cert::Cert; +use crate::cert::{Cert, MAX_CERT_TLV_LEN}; use crate::crypto::{self, KeyPair}; use crate::data_model::objects::*; use crate::data_model::sdm::dev_att; @@ -158,14 +158,14 @@ pub const CLUSTER: Cluster<'static> = Cluster { pub struct NocData { pub key_pair: KeyPair, - pub root_ca: Cert, + pub root_ca: heapless::Vec, } impl NocData { pub fn new(key_pair: KeyPair) -> Self { Self { key_pair, - root_ca: Cert::default(), + root_ca: heapless::Vec::new(), } } } @@ -259,8 +259,10 @@ impl<'a> NocCluster<'a> { writer.start_array(AttrDataWriter::TAG)?; self.fabric_mgr.borrow().for_each(|entry, fab_idx| { if !attr.fab_filter || attr.fab_idx == fab_idx { + let root_ca_cert = entry.get_root_ca()?; + entry - .get_fabric_desc(fab_idx) + .get_fabric_desc(fab_idx, &root_ca_cert)? .to_tlv(&mut writer, TagType::Anonymous)?; } @@ -351,12 +353,18 @@ impl<'a> NocCluster<'a> { let r = AddNocReq::from_tlv(data).map_err(|_| NocStatus::InvalidNOC)?; - let noc_value = Cert::new(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?; - info!("Received NOC as: {}", noc_value); - let icac_value = if !r.icac_value.0.is_empty() { - let cert = Cert::new(r.icac_value.0).map_err(|_| NocStatus::InvalidNOC)?; - info!("Received ICAC as: {}", cert); - Some(cert) + let noc_cert = Cert::new(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?; + info!("Received NOC as: {}", noc_cert); + + let noc = heapless::Vec::from_slice(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?; + + let icac = if !r.icac_value.0.is_empty() { + let icac_cert = Cert::new(r.icac_value.0).map_err(|_| NocStatus::InvalidNOC)?; + info!("Received ICAC as: {}", icac_cert); + + let icac = + heapless::Vec::from_slice(r.icac_value.0).map_err(|_| NocStatus::InvalidNOC)?; + Some(icac) } else { None }; @@ -364,8 +372,8 @@ impl<'a> NocCluster<'a> { let fabric = Fabric::new( noc_data.key_pair, noc_data.root_ca, - icac_value, - noc_value, + icac, + noc, r.ipk_value.0, r.vendor_id, "", @@ -592,7 +600,9 @@ impl<'a> NocCluster<'a> { let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; info!("Received Trusted Cert:{:x?}", req.str); - noc_data.root_ca = Cert::new(req.str.0)?; + noc_data.root_ca = + heapless::Vec::from_slice(req.str.0).map_err(|_| Error::BufferTooSmall)?; + // TODO } _ => (), } diff --git a/matter/src/error.rs b/matter/src/error.rs index 3a54b2c7..22a04e4c 100644 --- a/matter/src/error.rs +++ b/matter/src/error.rs @@ -15,14 +15,11 @@ * limitations under the License. */ -use alloc::string::FromUtf8Error; -use core::{array::TryFromSliceError, fmt}; +use core::{array::TryFromSliceError, fmt, str::Utf8Error}; use async_channel::{SendError, TryRecvError}; use log::error; -extern crate alloc; - #[derive(Debug, PartialEq, Clone, Copy)] pub enum Error { AttributeNotFound, @@ -166,8 +163,8 @@ impl From> for Error { } } -impl From for Error { - fn from(_e: FromUtf8Error) -> Self { +impl From for Error { + fn from(_e: Utf8Error) -> Self { Self::Utf8Fail } } diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index b7e2425a..6c9d389c 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -21,7 +21,7 @@ use byteorder::{BigEndian, ByteOrder, LittleEndian}; use log::{error, info}; use crate::{ - cert::Cert, + cert::{Cert, MAX_CERT_TLV_LEN}, crypto::{self, hkdf_sha256, HmacSha256, KeyPair}, error::Error, group_keys::KeySet, @@ -30,7 +30,6 @@ use crate::{ tlv::{OctetStr, TLVWriter, TagType, ToTLV, UtfStr}, }; -const MAX_CERT_TLV_LEN: usize = 300; const COMPRESSED_FABRIC_ID_LEN: usize = 8; macro_rules! fb_key { @@ -72,9 +71,9 @@ pub struct Fabric { fabric_id: u64, vendor_id: u16, key_pair: KeyPair, - pub root_ca: Cert, - pub icac: Option, - pub noc: Cert, + pub root_ca: heapless::Vec, + pub icac: Option>, + pub noc: heapless::Vec, pub ipk: KeySet, label: heapless::String<32>, mdns_service_name: heapless::String<33>, @@ -83,20 +82,25 @@ pub struct Fabric { impl Fabric { pub fn new( key_pair: KeyPair, - root_ca: Cert, - icac: Option, - noc: Cert, + root_ca: heapless::Vec, + icac: Option>, + noc: heapless::Vec, ipk: &[u8], vendor_id: u16, label: &str, ) -> Result { - let node_id = noc.get_node_id()?; - let fabric_id = noc.get_fabric_id()?; + let (node_id, fabric_id) = { + let noc_p = Cert::new(&noc)?; + (noc_p.get_node_id()?, noc_p.get_fabric_id()?) + }; let mut compressed_id = [0_u8; COMPRESSED_FABRIC_ID_LEN]; - Fabric::get_compressed_id(root_ca.get_pubkey(), fabric_id, &mut compressed_id)?; - let ipk = KeySet::new(ipk, &compressed_id)?; + let ipk = { + let root_ca_p = Cert::new(&root_ca)?; + Fabric::get_compressed_id(root_ca_p.get_pubkey(), fabric_id, &mut compressed_id)?; + KeySet::new(ipk, &compressed_id)? + }; let mut mdns_service_name = heapless::String::<33>::new(); for c in compressed_id { @@ -144,7 +148,7 @@ impl Fabric { let mut mac = HmacSha256::new(self.ipk.op_key())?; mac.update(random)?; - mac.update(self.root_ca.get_pubkey())?; + mac.update(self.get_root_ca()?.get_pubkey())?; let mut buf: [u8; 8] = [0; 8]; LittleEndian::write_u64(&mut buf, self.fabric_id); @@ -174,15 +178,25 @@ impl Fabric { self.fabric_id } - pub fn get_fabric_desc(&self, fab_idx: u8) -> FabricDescriptor { - FabricDescriptor { - root_public_key: OctetStr::new(self.root_ca.get_pubkey()), + pub fn get_root_ca(&self) -> Result, Error> { + Cert::new(&self.root_ca) + } + + pub fn get_fabric_desc<'a>( + &'a self, + fab_idx: u8, + root_ca_cert: &'a Cert, + ) -> Result, Error> { + let desc = FabricDescriptor { + root_public_key: OctetStr::new(root_ca_cert.get_pubkey()), vendor_id: self.vendor_id, fabric_id: self.fabric_id, node_id: self.node_id, label: UtfStr(self.label.as_bytes()), fab_idx: Some(fab_idx), - } + }; + + Ok(desc) } fn store(&self, index: usize, mut psm: T) -> Result<(), Error> @@ -191,19 +205,13 @@ impl Fabric { { let mut _kb = heapless::String::<32>::new(); - let mut buf = [0u8; MAX_CERT_TLV_LEN]; - let len = self.root_ca.as_tlv(&mut buf)?; - psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &buf[..len])?; - - let len = if let Some(icac) = &self.icac { - icac.as_tlv(&mut buf)? - } else { - 0 - }; - psm.set_kv_slice(fb_key!(index, ST_ICA, _kb), &buf[..len])?; + psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &self.root_ca)?; + psm.set_kv_slice( + fb_key!(index, ST_ICA, _kb), + self.icac.as_deref().unwrap_or(&[]), + )?; - let len = self.noc.as_tlv(&mut buf)?; - psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &buf[..len])?; + psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &self.noc)?; psm.set_kv_slice(fb_key!(index, ST_IPK, _kb), self.ipk.epoch_key())?; psm.set_kv_slice(fb_key!(index, ST_LBL, _kb), self.label.as_bytes())?; @@ -228,18 +236,21 @@ impl Fabric { let mut _kb = heapless::String::<32>::new(); let mut buf = [0u8; MAX_CERT_TLV_LEN]; - let root_ca = psm.get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf)?; - let root_ca = Cert::new(root_ca)?; + + let root_ca = + heapless::Vec::from_slice(psm.get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf)?) + .unwrap(); let icac = psm.get_kv_slice(fb_key!(index, ST_ICA, _kb), &mut buf)?; let icac = if !icac.is_empty() { - Some(Cert::new(icac)?) + Some(heapless::Vec::from_slice(icac).unwrap()) } else { None }; - let noc = psm.get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf)?; - let noc = Cert::new(noc)?; + let noc = + heapless::Vec::from_slice(psm.get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf)?) + .unwrap(); let label = psm.get_kv_slice(fb_key!(index, ST_LBL, _kb), &mut buf)?; let label: heapless::String<32> = core::str::from_utf8(label) @@ -293,21 +304,16 @@ impl Fabric { { let mut _kb = heapless::String::<32>::new(); - let mut buf = [0u8; MAX_CERT_TLV_LEN]; - let len = self.root_ca.as_tlv(&mut buf)?; - psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &buf[..len]) + psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &self.root_ca) .await?; - let len = if let Some(icac) = &self.icac { - icac.as_tlv(&mut buf)? - } else { - 0 - }; - psm.set_kv_slice(fb_key!(index, ST_ICA, _kb), &buf[..len]) - .await?; + psm.set_kv_slice( + fb_key!(index, ST_ICA, _kb), + self.icac.as_deref().unwrap_or(&[]), + ) + .await?; - let len = self.noc.as_tlv(&mut buf)?; - psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &buf[..len]) + psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &self.noc) .await?; psm.set_kv_slice(fb_key!(index, ST_IPK, _kb), self.ipk.epoch_key()) .await?; @@ -337,24 +343,27 @@ impl Fabric { let mut _kb = heapless::String::<32>::new(); let mut buf = [0u8; MAX_CERT_TLV_LEN]; - let root_ca = psm - .get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf) - .await?; - let root_ca = Cert::new(root_ca)?; + + let root_ca = heapless::Vec::from_slice( + psm.get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf) + .await?, + ) + .unwrap(); let icac = psm .get_kv_slice(fb_key!(index, ST_ICA, _kb), &mut buf) .await?; let icac = if !icac.is_empty() { - Some(Cert::new(icac)?) + Some(heapless::Vec::from_slice(icac).unwrap()) } else { None }; - let noc = psm - .get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf) - .await?; - let noc = Cert::new(noc)?; + let noc = heapless::Vec::from_slice( + psm.get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf) + .await?, + ) + .unwrap(); let label = psm .get_kv_slice(fb_key!(index, ST_LBL, _kb), &mut buf) diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index f5b9cb04..a722dae8 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -349,7 +349,9 @@ impl<'a> Case<'a> { verifier = verifier.add_cert(icac)?; } - verifier.add_cert(&fabric.root_ca)?.finalise()?; + verifier + .add_cert(&Cert::new(&fabric.root_ca)?)? + .finalise()?; Ok(()) } @@ -481,9 +483,9 @@ impl<'a> Case<'a> { let mut write_buf = WriteBuf::new(out); let mut tw = TLVWriter::new(&mut write_buf); tw.start_struct(TagType::Anonymous)?; - tw.str16_as(TagType::Context(1), |buf| fabric.noc.as_tlv(buf))?; - if let Some(icac_cert) = &fabric.icac { - tw.str16_as(TagType::Context(2), |buf| icac_cert.as_tlv(buf))? + tw.str16(TagType::Context(1), &fabric.noc)?; + if let Some(icac_cert) = fabric.icac.as_ref() { + tw.str16(TagType::Context(2), icac_cert)? }; tw.str8(TagType::Context(3), signature)?; @@ -523,9 +525,9 @@ impl<'a> Case<'a> { let mut write_buf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut write_buf); tw.start_struct(TagType::Anonymous)?; - tw.str16_as(TagType::Context(1), |buf| fabric.noc.as_tlv(buf))?; - if let Some(icac_cert) = &fabric.icac { - tw.str16_as(TagType::Context(2), |buf| icac_cert.as_tlv(buf))?; + tw.str16(TagType::Context(1), &fabric.noc)?; + if let Some(icac_cert) = fabric.icac.as_deref() { + tw.str16(TagType::Context(2), icac_cert)?; } tw.str8(TagType::Context(3), our_pub_key)?; tw.str8(TagType::Context(4), peer_pub_key)?; diff --git a/matter/src/tlv/traits.rs b/matter/src/tlv/traits.rs index c7b5e359..72cfab23 100644 --- a/matter/src/tlv/traits.rs +++ b/matter/src/tlv/traits.rs @@ -17,14 +17,10 @@ use super::{ElementType, TLVContainerIterator, TLVElement, TLVWriter, TagType}; use crate::error::Error; -use alloc::borrow::ToOwned; -use alloc::{string::String, vec::Vec}; use core::fmt::Debug; use core::slice::Iter; use log::error; -extern crate alloc; - pub trait FromTLV<'a> { fn from_tlv(t: &TLVElement<'a>) -> Result where @@ -118,14 +114,11 @@ totlv_for!(i8 u8 u16 u32 u64 bool); // // - UtfStr, OctetStr: These are versions that map to utfstr and ostr in the TLV spec // - These only have references into the original list -// - String, Vec: Is the owned version of utfstr and ostr, data is cloned into this -// - String is only partially implemented // // - TLVArray: Is an array of entries, with reference within the original list -// - TLVArrayOwned: Is the owned version of this, data is cloned into this /// Implements UTFString from the spec -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Debug, Copy, Clone, PartialEq, Default)] pub struct UtfStr<'a>(pub &'a [u8]); impl<'a> UtfStr<'a> { @@ -136,10 +129,6 @@ impl<'a> UtfStr<'a> { pub fn as_str(&self) -> Result<&str, Error> { core::str::from_utf8(self.0).map_err(|_| Error::Invalid) } - - pub fn to_string(self) -> Result { - String::from_utf8(self.0.to_vec()).map_err(|_| Error::Invalid) - } } impl<'a> ToTLV for UtfStr<'a> { @@ -155,7 +144,7 @@ impl<'a> FromTLV<'a> for UtfStr<'a> { } /// Implements OctetString from the spec -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Debug, Copy, Clone, PartialEq, Default)] pub struct OctetStr<'a>(pub &'a [u8]); impl<'a> OctetStr<'a> { @@ -176,41 +165,6 @@ impl<'a> ToTLV for OctetStr<'a> { } } -/// Implements the Owned version of Octet String -impl FromTLV<'_> for Vec { - fn from_tlv(t: &TLVElement) -> Result, Error> { - t.slice().map(|x| x.to_owned()) - } -} - -impl ToTLV for Vec { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - tw.str16(tag, self.as_slice()) - } -} - -/// Implements the Owned version of UTF String -impl FromTLV<'_> for String { - fn from_tlv(t: &TLVElement) -> Result { - match t.slice() { - Ok(x) => { - if let Ok(s) = String::from_utf8(x.to_vec()) { - Ok(s) - } else { - Err(Error::Invalid) - } - } - Err(e) => Err(e), - } - } -} - -impl ToTLV for String { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - tw.utf16(tag, self.as_bytes()) - } -} - /// Applies to all the Option<> Processing impl<'a, T: FromTLV<'a>> FromTLV<'a> for Option { fn from_tlv(t: &TLVElement<'a>) -> Result, Error> { @@ -279,37 +233,6 @@ impl ToTLV for Nullable { } } -/// Owned version of a TLVArray -pub struct TLVArrayOwned(Vec); -impl<'a, T: FromTLV<'a>> FromTLV<'a> for TLVArrayOwned { - fn from_tlv(t: &TLVElement<'a>) -> Result { - t.confirm_array()?; - let mut vec = Vec::::new(); - if let Some(tlv_iter) = t.enter() { - for element in tlv_iter { - vec.push(T::from_tlv(&element)?); - } - } - Ok(Self(vec)) - } -} - -impl ToTLV for TLVArrayOwned { - fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { - tw.start_array(tag_type)?; - for t in &self.0 { - t.to_tlv(tw, TagType::Anonymous)?; - } - tw.end_container() - } -} - -impl TLVArrayOwned { - pub fn iter(&self) -> Iter { - self.0.iter() - } -} - #[derive(Copy, Clone)] pub enum TLVArray<'a, T> { // This is used for the to-tlv path @@ -390,18 +313,23 @@ where } } -impl<'a, T: ToTLV> ToTLV for TLVArray<'a, T> { +impl<'a, T: FromTLV<'a> + Copy + ToTLV> ToTLV for TLVArray<'a, T> { fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { - match *self { - Self::Slice(s) => { - tw.start_array(tag_type)?; - for a in s { - a.to_tlv(tw, TagType::Anonymous)?; - } - tw.end_container() - } - Self::Ptr(t) => t.to_tlv(tw, tag_type), + tw.start_array(tag_type)?; + for a in self.iter() { + a.to_tlv(tw, TagType::Anonymous)?; } + tw.end_container() + // match *self { + // Self::Slice(s) => { + // tw.start_array(tag_type)?; + // for a in s { + // a.to_tlv(tw, TagType::Anonymous)?; + // } + // tw.end_container() + // } + // Self::Ptr(t) => t.to_tlv(tw, tag_type), <-- TODO: this fails the unit tests of Cert from/to TLV + // } } } @@ -414,10 +342,17 @@ impl<'a, T> FromTLV<'a> for TLVArray<'a, T> { impl<'a, T: Debug + ToTLV + FromTLV<'a> + Copy> Debug for TLVArray<'a, T> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "TLVArray [")?; + let mut first = true; for i in self.iter() { - writeln!(f, "{:?}", i)?; + if !first { + write!(f, ", ")?; + } + + write!(f, "{:?}", i)?; + first = false; } - writeln!(f) + write!(f, "]") } } From b4b549bb105a6f6e38815e6451f4ea1949002bc6 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 24 Apr 2023 07:20:49 +0000 Subject: [PATCH 06/72] Fix several no_std incompatibilities --- matter/Cargo.toml | 7 ++++--- matter/src/acl.rs | 2 +- matter/src/cert/mod.rs | 12 ++++++++++-- matter/src/error.rs | 1 + matter/src/transport/dedup.rs | 10 ++++++---- matter/src/transport/exchange.rs | 14 +++++++------- 6 files changed, 29 insertions(+), 17 deletions(-) diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 0987e10d..b6260838 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -16,10 +16,11 @@ path = "src/lib.rs" [features] default = ["std", "crypto_mbedtls"] -std = [] +std = ["alloc"] +alloc = [] nightly = [] crypto_openssl = ["openssl", "foreign-types", "hmac", "sha2"] -crypto_mbedtls = ["mbedtls"] +crypto_mbedtls = ["mbedtls", "alloc"] crypto_esp_mbedtls = ["esp-idf-sys"] crypto_rustcrypto = ["sha2", "hmac", "pbkdf2", "hkdf", "aes", "ccm", "p256", "elliptic-curve", "crypto-bigint", "x509-cert"] @@ -33,6 +34,7 @@ generic-array = "0.14.6" num = "0.4" num-derive = "0.3.3" num-traits = "0.2.15" +strum = { version = "0.24", features = ["derive"], default-features = false, no-default-feature = true } log = { version = "0.4.17", features = ["max_level_debug", "release_max_level_debug"] } env_logger = { version = "0.10.0", default-features = false, features = [] } rand = "0.8.5" @@ -44,7 +46,6 @@ owning_ref = "0.4.1" safemem = "0.3.3" chrono = { version = "0.4.23", default-features = false, features = ["clock", "std"] } async-channel = "1.8" -strum = { version = "0.24", features = ["derive"], no-default-feature = true } # crypto openssl = { git = "https://github.com/sfackler/rust-openssl", optional = true } diff --git a/matter/src/acl.rs b/matter/src/acl.rs index 2bfd0d61..d73ce47f 100644 --- a/matter/src/acl.rs +++ b/matter/src/acl.rs @@ -143,7 +143,7 @@ impl AccessorSubjects { } impl Display for AccessorSubjects { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::result::Result<(), core::fmt::Error> { write!(f, "[")?; for i in self.0 { if is_noc_cat(i) { diff --git a/matter/src/cert/mod.rs b/matter/src/cert/mod.rs index 757a9d66..918737af 100644 --- a/matter/src/cert/mod.rs +++ b/matter/src/cert/mod.rs @@ -519,8 +519,16 @@ fn encode_dn_value( w.oid(name, oid)?; match value { DistNameValue::Uint(v) => match expected_len { - Some(IntToStringLen::Len16) => w.utf8str("", format!("{:016X}", v).as_str())?, - Some(IntToStringLen::Len8) => w.utf8str("", format!("{:08X}", v).as_str())?, + Some(IntToStringLen::Len16) => { + let mut string = heapless::String::<32>::new(); + write!(&mut string, "{:016X}", v).unwrap(); + w.utf8str("", &string)? + } + Some(IntToStringLen::Len8) => { + let mut string = heapless::String::<32>::new(); + write!(&mut string, "{:08X}", v).unwrap(); + w.utf8str("", &string)? + } _ => { error!("Invalid encoding"); return Err(Error::Invalid); diff --git a/matter/src/error.rs b/matter/src/error.rs index 22a04e4c..04d55b3b 100644 --- a/matter/src/error.rs +++ b/matter/src/error.rs @@ -182,4 +182,5 @@ impl fmt::Display for Error { } } +#[cfg(feature = "std")] impl std::error::Error for Error {} diff --git a/matter/src/transport/dedup.rs b/matter/src/transport/dedup.rs index 981bda74..f2c382e8 100644 --- a/matter/src/transport/dedup.rs +++ b/matter/src/transport/dedup.rs @@ -86,6 +86,8 @@ impl RxCtrState { #[cfg(test)] mod tests { + use log::info; + use super::RxCtrState; const ENCRYPTED: bool = true; @@ -194,10 +196,10 @@ mod tests { #[test] fn unencrypted_device_reboot() { - println!("Sub 65532 is {:?}", 1_u16.overflowing_sub(65532)); - println!("Sub 65535 is {:?}", 1_u16.overflowing_sub(65535)); - println!("Sub 11-13 is {:?}", 11_u32.wrapping_sub(13_u32) as i32); - println!("Sub regular is {:?}", 2000_u16.overflowing_sub(1998)); + info!("Sub 65532 is {:?}", 1_u16.overflowing_sub(65532)); + info!("Sub 65535 is {:?}", 1_u16.overflowing_sub(65535)); + info!("Sub 11-13 is {:?}", 11_u32.wrapping_sub(13_u32) as i32); + info!("Sub regular is {:?}", 2000_u16.overflowing_sub(1998)); let mut s = RxCtrState::new(20010); assert_ndup(s.recv(20011, NOT_ENCRYPTED)); diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index c28a5b2e..053bf793 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -457,13 +457,9 @@ mod tests { error::Error, transport::{ network::Address, - packet::{Packet, MAX_TX_BUF_SIZE}, - session::{CloneData, SessionMode, MAX_SESSIONS}, - }, - utils::{ - epoch::{dummy_epoch, sys_epoch}, - rand::dummy_rand, + session::{CloneData, SessionMode}, }, + utils::{epoch::dummy_epoch, rand::dummy_rand}, }; use super::{ExchangeMgr, Role}; @@ -526,13 +522,17 @@ mod tests { } } + #[cfg(feature = "std")] #[test] /// We purposefuly overflow the sessions /// and when the overflow happens, we confirm that /// - The sessions are evicted in LRU /// - The exchanges associated with those sessions are evicted too fn test_sess_evict() { - let mut mgr = ExchangeMgr::new(sys_epoch, dummy_rand); + use crate::transport::packet::{Packet, MAX_TX_BUF_SIZE}; + use crate::transport::session::MAX_SESSIONS; + + let mut mgr = ExchangeMgr::new(crate::utils::epoch::sys_epoch, dummy_rand); fill_sessions(&mut mgr, MAX_SESSIONS + 1); // Sessions are now full from local session id 1 to 16 From bd87ac4ab390e5f3b76196c387411912b8145759 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 24 Apr 2023 07:21:35 +0000 Subject: [PATCH 07/72] Linux & MacOS mDNS services now implement the Mdns trait --- matter/src/sys/mod.rs | 8 +-- matter/src/sys/sys_linux.rs | 114 +++++++++++++++++++++++++++--------- matter/src/sys/sys_macos.rs | 107 ++++++++++++++++++++++++++------- 3 files changed, 175 insertions(+), 54 deletions(-) diff --git a/matter/src/sys/mod.rs b/matter/src/sys/mod.rs index 9b5219ef..0ce65e71 100644 --- a/matter/src/sys/mod.rs +++ b/matter/src/sys/mod.rs @@ -15,14 +15,14 @@ * limitations under the License. */ -#[cfg(target_os = "macos")] +#[cfg(all(feature = "std", target_os = "macos"))] mod sys_macos; -#[cfg(target_os = "macos")] +#[cfg(all(feature = "std", target_os = "macos"))] pub use self::sys_macos::*; -#[cfg(target_os = "linux")] +#[cfg(all(feature = "std", target_os = "linux"))] mod sys_linux; -#[cfg(target_os = "linux")] +#[cfg(all(feature = "std", target_os = "linux"))] pub use self::sys_linux::*; pub const SPAKE2_ITERATION_COUNT: u32 = 2000; diff --git a/matter/src/sys/sys_linux.rs b/matter/src/sys/sys_linux.rs index 881764df..0d3f0dc2 100644 --- a/matter/src/sys/sys_linux.rs +++ b/matter/src/sys/sys_linux.rs @@ -16,43 +16,101 @@ */ use crate::error::Error; -use lazy_static::lazy_static; +use crate::mdns::Mdns; use libmdns::{Responder, Service}; use log::info; -use std::sync::{Arc, Mutex}; +use std::collections::HashMap; use std::vec::Vec; -#[allow(dead_code)] -pub struct SysMdnsService { - service: Service, +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct ServiceId { + name: String, + service_type: String, + port: u16, } -lazy_static! { - static ref RESPONDER: Arc> = Arc::new(Mutex::new(Responder::new().unwrap())); +pub struct LinuxMdns { + responder: Responder, + services: HashMap, } -/// Publish a mDNS service -/// name - can be a service name (comma separate subtypes may follow) -/// regtype - registration type (e.g. _matter_.tcp etc) -/// port - the port -pub fn sys_publish_service( - name: &str, - regtype: &str, - port: u16, - txt_kvs: &[[&str; 2]], -) -> Result { - info!("mDNS Registration Type {}", regtype); - info!("mDNS properties {:?}", txt_kvs); - - let mut properties = Vec::new(); - for kvs in txt_kvs { - info!("mDNS TXT key {} val {}", kvs[0], kvs[1]); - properties.push(format!("{}={}", kvs[0], kvs[1])); +impl LinuxMdns { + pub fn new() -> Result { + let responder = Responder::new()?; + + Ok(Self { + responder, + services: HashMap::new(), + }) } - let properties: Vec<&str> = properties.iter().map(|entry| entry.as_str()).collect(); - let responder = RESPONDER.lock().map_err(|_| Error::MdnsError)?; - let service = responder.register(regtype.to_owned(), name.to_owned(), port, &properties); + pub fn add( + &mut self, + name: &str, + service_type: &str, + port: u16, + txt_kvs: &[(&str, &str)], + ) -> Result<(), Error> { + info!( + "Registering mDNS service {}/{}/{}", + name, service_type, port + ); + + let _ = self.remove(name, service_type, port); + + let mut properties = Vec::new(); + for kvs in txt_kvs { + info!("mDNS TXT key {} val {}", kvs.0, kvs.1); + properties.push(format!("{}={}", kvs.0, kvs.1)); + } + let properties: Vec<&str> = properties.iter().map(|entry| entry.as_str()).collect(); + + let service = + self.responder + .register(service_type.to_owned(), name.to_owned(), port, &properties); + + self.services.insert( + ServiceId { + name: name.into(), + service_type: service_type.into(), + port, + }, + service, + ); + + Ok(()) + } - Ok(SysMdnsService { service }) + pub fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { + let id = ServiceId { + name: name.into(), + service_type: service_type.into(), + port, + }; + + if self.services.remove(&id).is_some() { + info!( + "Deregistering mDNS service {}/{}/{}", + name, service_type, port + ); + } + + Ok(()) + } +} + +impl Mdns for LinuxMdns { + fn add( + &mut self, + name: &str, + service_type: &str, + port: u16, + txt_kvs: &[(&str, &str)], + ) -> Result<(), Error> { + LinuxMdns::add(self, name, service_type, port, txt_kvs) + } + + fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { + LinuxMdns::remove(self, name, service_type, port) + } } diff --git a/matter/src/sys/sys_macos.rs b/matter/src/sys/sys_macos.rs index ba2ce223..d8ffe3ac 100644 --- a/matter/src/sys/sys_macos.rs +++ b/matter/src/sys/sys_macos.rs @@ -15,32 +15,95 @@ * limitations under the License. */ -use crate::error::Error; +use std::collections::HashMap; + +use crate::{error::Error, mdns::Mdns}; use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService}; use log::info; -#[allow(dead_code)] -pub struct SysMdnsService { - s: RegisteredDnsService, +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct ServiceId { + name: String, + service_type: String, + port: u16, } -/// Publish a mDNS service -/// name - can be a service name (comma separate subtypes may follow) -/// regtype - registration type (e.g. _matter_.tcp etc) -/// port - the port -pub fn sys_publish_service( - name: &str, - regtype: &str, - port: u16, - txt_kvs: &[[&str; 2]], -) -> Result { - let mut builder = DNSServiceBuilder::new(regtype, port).with_name(name); - - info!("mDNS Registration Type {}", regtype); - for kvs in txt_kvs { - info!("mDNS TXT key {} val {}", kvs[0], kvs[1]); - builder = builder.with_key_value(kvs[0].to_string(), kvs[1].to_string()); +pub struct MacOsMdns { + services: HashMap, +} + +impl MacOsMdns { + pub fn new() -> Result { + Ok(Self { + services: HashMap::new(), + }) + } + + pub fn add( + &mut self, + name: &str, + service_type: &str, + port: u16, + txt_kvs: &[(&str, &str)], + ) -> Result<(), Error> { + info!( + "Registering mDNS service {}/{}/{}", + name, service_type, port + ); + + let _ = self.remove(name, service_type, port); + + let mut builder = DNSServiceBuilder::new(service_type, port).with_name(name); + + for kvs in txt_kvs { + info!("mDNS TXT key {} val {}", kvs.0, kvs.1); + builder = builder.with_key_value(kvs.0.to_string(), kvs.1.to_string()); + } + + let service = builder.register().map_err(|_| Error::MdnsError)?; + + self.services.insert( + ServiceId { + name: name.into(), + service_type: service_type.into(), + port, + }, + service, + ); + + Ok(()) + } + + pub fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { + let id = ServiceId { + name: name.into(), + service_type: service_type.into(), + port, + }; + + if self.services.remove(&id).is_some() { + info!( + "Deregistering mDNS service {}/{}/{}", + name, service_type, port + ); + } + + Ok(()) + } +} + +impl Mdns for MacOsMdns { + fn add( + &mut self, + name: &str, + service_type: &str, + port: u16, + txt_kvs: &[(&str, &str)], + ) -> Result<(), Error> { + MacOsMdns::add(self, name, service_type, port, txt_kvs) + } + + fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { + MacOsMdns::remove(self, name, service_type, port) } - let s = builder.register().map_err(|_| Error::MdnsError)?; - Ok(SysMdnsService { s }) } From d9c99d73eef72e0b0f94ae8f8afd5348f5254910 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 24 Apr 2023 08:17:08 +0000 Subject: [PATCH 08/72] Chrono dep made optional --- matter/src/cert/asn1_writer.rs | 50 ++++++++++++++------------ matter/src/cert/mod.rs | 60 ++++++++++++++++++++----------- matter/src/cert/printer.rs | 29 +++++---------- matter/src/core.rs | 14 +++++++- matter/src/secure_channel/case.rs | 24 +++++++++---- matter/src/secure_channel/core.rs | 12 +++++-- matter/src/transport/mgr.rs | 16 +++++++-- matter/src/utils/epoch.rs | 47 ++++++++++++++++++++++++ matter/tests/common/im_engine.rs | 8 +++-- 9 files changed, 181 insertions(+), 79 deletions(-) diff --git a/matter/src/cert/asn1_writer.rs b/matter/src/cert/asn1_writer.rs index 675546a0..b6f4ab78 100644 --- a/matter/src/cert/asn1_writer.rs +++ b/matter/src/cert/asn1_writer.rs @@ -16,10 +16,11 @@ */ use super::{CertConsumer, MAX_DEPTH}; -use crate::error::Error; -use chrono::{Datelike, TimeZone, Utc}; // TODO -use core::fmt::Write; -use log::warn; +use crate::{ + error::Error, + utils::epoch::{UtcCalendar, MATTER_EPOCH_SECS}, +}; +use core::{fmt::Write, time::Duration}; #[derive(Debug)] pub struct ASN1Writer<'a> { @@ -261,31 +262,34 @@ impl<'a> CertConsumer for ASN1Writer<'a> { self.write_str(0x06, oid) } - fn utctime(&mut self, _tag: &str, epoch: u32) -> Result<(), Error> { - let mut matter_epoch = Utc - .with_ymd_and_hms(2000, 1, 1, 0, 0, 0) - .unwrap() - .timestamp(); + fn utctime(&mut self, _tag: &str, epoch: u32, utc_calendar: UtcCalendar) -> Result<(), Error> { + let matter_epoch = MATTER_EPOCH_SECS + epoch as u64; - matter_epoch += epoch as i64; + let dt = utc_calendar(Duration::from_secs(matter_epoch as _)); - let dt = match Utc.timestamp_opt(matter_epoch, 0) { - chrono::LocalResult::None => return Err(Error::InvalidTime), - chrono::LocalResult::Single(s) => s, - chrono::LocalResult::Ambiguous(_, a) => { - warn!("Ambiguous time for epoch {epoch}; returning latest timestamp: {a}"); - a - } - }; + let mut time_str: heapless::String<32> = heapless::String::<32>::new(); - if dt.year() >= 2050 { + if dt.year >= 2050 { // If year is >= 2050, ASN.1 requires it to be Generalised Time - let mut time_str = heapless::String::<32>::new(); - write!(&mut time_str, "{}Z", dt.format("%Y%m%d%H%M%S")).unwrap(); + write!( + &mut time_str, + "{:04}{:02}{:02}{:02}{:02}{:02}Z", + dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second + ) + .unwrap(); self.write_str(0x18, time_str.as_bytes()) } else { - let mut time_str = heapless::String::<32>::new(); - write!(&mut time_str, "{}Z", dt.format("%y%m%d%H%M%S")).unwrap(); + write!( + &mut time_str, + "{:02}{:02}{:02}{:02}{:02}{:02}Z", + dt.year % 100, + dt.month, + dt.day, + dt.hour, + dt.minute, + dt.second + ) + .unwrap(); self.write_str(0x17, time_str.as_bytes()) } } diff --git a/matter/src/cert/mod.rs b/matter/src/cert/mod.rs index 918737af..7af18677 100644 --- a/matter/src/cert/mod.rs +++ b/matter/src/cert/mod.rs @@ -21,7 +21,7 @@ use crate::{ crypto::KeyPair, error::Error, tlv::{self, FromTLV, OctetStr, TLVArray, TLVElement, TLVWriter, TagType, ToTLV}, - utils::writebuf::WriteBuf, + utils::{epoch::UtcCalendar, writebuf::WriteBuf}, }; use log::error; use num_derive::FromPrimitive; @@ -617,17 +617,21 @@ impl<'a> Cert<'a> { Ok(wb.as_slice().len()) } - pub fn as_asn1(&self, buf: &mut [u8]) -> Result { + pub fn as_asn1(&self, buf: &mut [u8], utc_calendar: UtcCalendar) -> Result { let mut w = ASN1Writer::new(buf); - self.encode(&mut w)?; + self.encode(&mut w, Some(utc_calendar))?; Ok(w.as_slice().len()) } - pub fn verify_chain_start(&self) -> CertVerifier { - CertVerifier::new(self) + pub fn verify_chain_start(&self, utc_calendar: UtcCalendar) -> CertVerifier { + CertVerifier::new(self, utc_calendar) } - fn encode(&self, w: &mut dyn CertConsumer) -> Result<(), Error> { + fn encode( + &self, + w: &mut dyn CertConsumer, + utc_calendar: Option, + ) -> Result<(), Error> { w.start_seq("")?; w.start_ctx("Version:", 0)?; @@ -646,8 +650,10 @@ impl<'a> Cert<'a> { self.issuer.encode("Issuer:", w)?; w.start_seq("Validity:")?; - w.utctime("Not Before:", self.not_before)?; - w.utctime("Not After:", self.not_after)?; + if let Some(utc_calendar) = utc_calendar { + w.utctime("Not Before:", self.not_before, utc_calendar)?; + w.utctime("Not After:", self.not_after, utc_calendar)?; + } w.end_seq()?; self.subject.encode("Subject:", w)?; @@ -679,7 +685,7 @@ impl<'a> fmt::Display for Cert<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut printer = CertPrinter::new(f); let _ = self - .encode(&mut printer) + .encode(&mut printer, None) .map_err(|e| error!("Error decoding certificate: {}", e)); // Signature is not encoded by the Cert Decoder writeln!(f, "Signature: {:x?}", self.get_signature()) @@ -688,11 +694,12 @@ impl<'a> fmt::Display for Cert<'a> { pub struct CertVerifier<'a> { cert: &'a Cert<'a>, + utc_calendar: UtcCalendar, } impl<'a> CertVerifier<'a> { - pub fn new(cert: &'a Cert) -> Self { - Self { cert } + pub fn new(cert: &'a Cert, utc_calendar: UtcCalendar) -> Self { + Self { cert, utc_calendar } } pub fn add_cert(self, parent: &'a Cert) -> Result, Error> { @@ -700,7 +707,7 @@ impl<'a> CertVerifier<'a> { return Err(Error::InvalidAuthKey); } let mut asn1 = [0u8; MAX_ASN1_CERT_SIZE]; - let len = self.cert.as_asn1(&mut asn1)?; + let len = self.cert.as_asn1(&mut asn1, self.utc_calendar)?; let asn1 = &asn1[..len]; let k = KeyPair::new_from_public(parent.get_pubkey())?; @@ -713,7 +720,7 @@ impl<'a> CertVerifier<'a> { })?; // TODO: other validation checks - Ok(CertVerifier::new(parent)) + Ok(CertVerifier::new(parent, self.utc_calendar)) } pub fn finalise(self) -> Result<(), Error> { @@ -740,7 +747,7 @@ pub trait CertConsumer { fn start_ctx(&mut self, tag: &str, id: u8) -> Result<(), Error>; fn end_ctx(&mut self) -> Result<(), Error>; fn oid(&mut self, tag: &str, oid: &[u8]) -> Result<(), Error>; - fn utctime(&mut self, tag: &str, epoch: u32) -> Result<(), Error>; + fn utctime(&mut self, tag: &str, epoch: u32, utc_calendar: UtcCalendar) -> Result<(), Error>; } const MAX_DEPTH: usize = 10; @@ -758,36 +765,44 @@ mod tests { use crate::tlv::{self, FromTLV, TLVWriter, TagType, ToTLV}; use crate::utils::writebuf::WriteBuf; + #[cfg(feature = "std")] #[test] fn test_asn1_encode_success() { { let mut asn1_buf = [0u8; 1000]; let c = Cert::new(&test_vectors::CHIP_CERT_INPUT1).unwrap(); - let len = c.as_asn1(&mut asn1_buf).unwrap(); + let len = c + .as_asn1(&mut asn1_buf, crate::utils::epoch::sys_utc_calendar) + .unwrap(); assert_eq!(&test_vectors::ASN1_OUTPUT1, &asn1_buf[..len]); } { let mut asn1_buf = [0u8; 1000]; let c = Cert::new(&test_vectors::CHIP_CERT_INPUT2).unwrap(); - let len = c.as_asn1(&mut asn1_buf).unwrap(); + let len = c + .as_asn1(&mut asn1_buf, crate::utils::epoch::sys_utc_calendar) + .unwrap(); assert_eq!(&test_vectors::ASN1_OUTPUT2, &asn1_buf[..len]); } { let mut asn1_buf = [0u8; 1000]; let c = Cert::new(&test_vectors::CHIP_CERT_TXT_IN_DN).unwrap(); - let len = c.as_asn1(&mut asn1_buf).unwrap(); + let len = c + .as_asn1(&mut asn1_buf, crate::utils::epoch::sys_utc_calendar) + .unwrap(); assert_eq!(&test_vectors::ASN1_OUTPUT_TXT_IN_DN, &asn1_buf[..len]); } } + #[cfg(feature = "std")] #[test] fn test_verify_chain_success() { let noc = Cert::new(&test_vectors::NOC1_SUCCESS).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); let rca = Cert::new(&test_vectors::RCA1_SUCCESS).unwrap(); - let a = noc.verify_chain_start(); + let a = noc.verify_chain_start(crate::utils::epoch::sys_utc_calendar); a.add_cert(&icac) .unwrap() .add_cert(&rca) @@ -796,31 +811,34 @@ mod tests { .unwrap(); } + #[cfg(feature = "std")] #[test] fn test_verify_chain_incomplete() { // The chain doesn't lead up to a self-signed certificate let noc = Cert::new(&test_vectors::NOC1_SUCCESS).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); - let a = noc.verify_chain_start(); + let a = noc.verify_chain_start(crate::utils::epoch::sys_utc_calendar); assert_eq!( Err(Error::InvalidAuthKey), a.add_cert(&icac).unwrap().finalise() ); } + #[cfg(feature = "std")] #[test] fn test_auth_key_chain_incorrect() { let noc = Cert::new(&test_vectors::NOC1_AUTH_KEY_FAIL).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); - let a = noc.verify_chain_start(); + let a = noc.verify_chain_start(crate::utils::epoch::sys_utc_calendar); assert_eq!(Err(Error::InvalidAuthKey), a.add_cert(&icac).map(|_| ())); } + #[cfg(feature = "std")] #[test] fn test_cert_corrupted() { let noc = Cert::new(&test_vectors::NOC1_CORRUPT_CERT).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); - let a = noc.verify_chain_start(); + let a = noc.verify_chain_start(crate::utils::epoch::sys_utc_calendar); assert_eq!(Err(Error::InvalidSignature), a.add_cert(&icac).map(|_| ())); } diff --git a/matter/src/cert/printer.rs b/matter/src/cert/printer.rs index b9336073..ae079573 100644 --- a/matter/src/cert/printer.rs +++ b/matter/src/cert/printer.rs @@ -16,10 +16,11 @@ */ use super::{CertConsumer, MAX_DEPTH}; -use crate::error::Error; -use chrono::{TimeZone, Utc}; -use core::fmt; -use log::warn; +use crate::{ + error::Error, + utils::epoch::{UtcCalendar, MATTER_EPOCH_SECS}, +}; +use core::{fmt, time::Duration}; pub struct CertPrinter<'a, 'b> { level: usize, @@ -122,24 +123,12 @@ impl<'a, 'b> CertConsumer for CertPrinter<'a, 'b> { } Ok(()) } - fn utctime(&mut self, tag: &str, epoch: u32) -> Result<(), Error> { - let mut matter_epoch = Utc - .with_ymd_and_hms(2000, 1, 1, 0, 0, 0) - .unwrap() - .timestamp(); + fn utctime(&mut self, tag: &str, epoch: u32, utc_calendar: UtcCalendar) -> Result<(), Error> { + let matter_epoch = MATTER_EPOCH_SECS + epoch as u64; - matter_epoch += epoch as i64; + let dt = utc_calendar(Duration::from_secs(matter_epoch as _)); - let dt = match Utc.timestamp_opt(matter_epoch, 0) { - chrono::LocalResult::None => return Err(Error::InvalidTime), - chrono::LocalResult::Single(s) => s, - chrono::LocalResult::Ambiguous(_, a) => { - warn!("Ambiguous time for epoch {epoch}; returning latest timestamp: {a}"); - a - } - }; - - let _ = writeln!(self.f, "{} {} {}", SPACE[self.level], tag, dt); + let _ = writeln!(self.f, "{} {} {:?}", SPACE[self.level], tag, dt); Ok(()) } } diff --git a/matter/src/core.rs b/matter/src/core.rs index 1c3eb25a..0b555cc5 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -26,7 +26,10 @@ use crate::{ pairing::{print_pairing_code_and_qr, DiscoveryCapabilities}, secure_channel::{pake::PaseMgr, spake2p::VerifierData}, transport::udp::MATTER_PORT, - utils::{epoch::Epoch, rand::Rand}, + utils::{ + epoch::{Epoch, UtcCalendar}, + rand::Rand, + }, }; /// Device Commissioning Data @@ -46,6 +49,7 @@ pub struct Matter<'a> { pub mdns_mgr: RefCell>, pub epoch: Epoch, pub rand: Rand, + pub utc_calendar: UtcCalendar, pub dev_det: &'a BasicInfoConfig<'a>, } @@ -61,6 +65,7 @@ impl<'a> Matter<'a> { mdns: &'a mut dyn Mdns, epoch: Epoch, rand: Rand, + utc_calendar: UtcCalendar, ) -> Self { Self { fabric_mgr: RefCell::new(FabricMgr::new()), @@ -76,6 +81,7 @@ impl<'a> Matter<'a> { )), epoch, rand, + utc_calendar, dev_det, } } @@ -150,3 +156,9 @@ impl<'a> Borrow for Matter<'a> { &self.rand } } + +impl<'a> Borrow for Matter<'a> { + fn borrow(&self) -> &UtcCalendar { + &self.utc_calendar + } +} diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index a722dae8..4ffb6b23 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -33,7 +33,7 @@ use crate::{ queue::{Msg, WorkQ}, session::{CaseDetails, CloneData, NocCatIds, SessionMode}, }, - utils::{rand::Rand, writebuf::WriteBuf}, + utils::{epoch::UtcCalendar, rand::Rand, writebuf::WriteBuf}, }; #[derive(PartialEq)] @@ -71,11 +71,16 @@ impl CaseSession { pub struct Case<'a> { fabric_mgr: &'a RefCell, rand: Rand, + utc_calendar: UtcCalendar, } impl<'a> Case<'a> { - pub fn new(fabric_mgr: &'a RefCell, rand: Rand) -> Self { - Self { fabric_mgr, rand } + pub fn new(fabric_mgr: &'a RefCell, rand: Rand, utc_calendar: UtcCalendar) -> Self { + Self { + fabric_mgr, + rand, + utc_calendar, + } } pub fn casesigma3_handler(&mut self, ctx: &mut ProtoCtx) -> Result { @@ -126,7 +131,9 @@ impl<'a> Case<'a> { if let Some(icac) = d.initiator_icac { initiator_icac = Some(Cert::new(icac.0)?); } - if let Err(e) = Case::validate_certs(fabric, &initiator_noc, &initiator_icac) { + if let Err(e) = + Case::validate_certs(fabric, &initiator_noc, &initiator_icac, self.utc_calendar) + { error!("Certificate Chain doesn't match: {}", e); common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?; ctx.exch_ctx.exch.close(); @@ -332,8 +339,13 @@ impl<'a> Case<'a> { Ok(()) } - fn validate_certs(fabric: &Fabric, noc: &Cert, icac: &Option) -> Result<(), Error> { - let mut verifier = noc.verify_chain_start(); + fn validate_certs( + fabric: &Fabric, + noc: &Cert, + icac: &Option, + utc_calendar: UtcCalendar, + ) -> Result<(), Error> { + let mut verifier = noc.verify_chain_start(utc_calendar); if fabric.get_fabric_id() != noc.get_fabric_id()? { return Err(Error::Invalid); diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index 5ca18042..be806a74 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -18,8 +18,13 @@ use core::cell::RefCell; use crate::{ - error::*, fabric::FabricMgr, mdns::MdnsMgr, secure_channel::common::*, tlv, - transport::proto_ctx::ProtoCtx, utils::rand::Rand, + error::*, + fabric::FabricMgr, + mdns::MdnsMgr, + secure_channel::common::*, + tlv, + transport::proto_ctx::ProtoCtx, + utils::{epoch::UtcCalendar, rand::Rand}, }; use log::{error, info}; use num; @@ -41,9 +46,10 @@ impl<'a> SecureChannel<'a> { fabric_mgr: &'a RefCell, mdns: &'a RefCell>, rand: Rand, + utc_calendar: UtcCalendar, ) -> Self { SecureChannel { - case: Case::new(fabric_mgr, rand), + case: Case::new(fabric_mgr, rand, utc_calendar), pase, mdns, } diff --git a/matter/src/transport/mgr.rs b/matter/src/transport/mgr.rs index 349cfdee..2ada9429 100644 --- a/matter/src/transport/mgr.rs +++ b/matter/src/transport/mgr.rs @@ -29,7 +29,7 @@ use crate::secure_channel::common::PROTO_ID_SECURE_CHANNEL; use crate::secure_channel::core::SecureChannel; use crate::transport::mrp::ReliableMessage; use crate::transport::{exchange, packet::Packet}; -use crate::utils::epoch::Epoch; +use crate::utils::epoch::{Epoch, UtcCalendar}; use crate::utils::rand::Rand; use super::proto_ctx::ProtoCtx; @@ -167,13 +167,23 @@ pub struct TransportMgr<'a> { impl<'a> TransportMgr<'a> { pub fn new< - T: Borrow> + Borrow> + Borrow + Borrow, + T: Borrow> + + Borrow> + + Borrow + + Borrow + + Borrow, >( matter: &'a T, mdns_mgr: &'a RefCell>, ) -> Self { Self::wrap( - SecureChannel::new(matter.borrow(), matter.borrow(), mdns_mgr, *matter.borrow()), + SecureChannel::new( + matter.borrow(), + matter.borrow(), + mdns_mgr, + *matter.borrow(), + *matter.borrow(), + ), *matter.borrow(), *matter.borrow(), ) diff --git a/matter/src/utils/epoch.rs b/matter/src/utils/epoch.rs index 999cdf38..7d08bfe7 100644 --- a/matter/src/utils/epoch.rs +++ b/matter/src/utils/epoch.rs @@ -2,13 +2,60 @@ use core::time::Duration; pub type Epoch = fn() -> Duration; +pub type UtcCalendar = fn(Duration) -> UtcDate; + +pub const MATTER_EPOCH_SECS: u64 = 946684800; // Seconds from 1970/01/01 00:00:00 till 2000/01/01 00:00:00 UTC + +#[derive(Default, Debug, Clone, Eq, PartialEq)] +pub struct UtcDate { + pub year: u16, + pub month: u8, // 1 - 12 + pub day: u8, // 1 - 31 + pub hour: u8, // 0 - 23 + pub minute: u8, + pub second: u8, + pub millis: u16, +} + pub fn dummy_epoch() -> Duration { Duration::from_secs(0) } +pub fn dummy_utc_calendar(_duration: Duration) -> UtcDate { + Default::default() +} + #[cfg(feature = "std")] pub fn sys_epoch() -> Duration { std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() } + +#[cfg(feature = "std")] +pub fn sys_utc_calendar(duration: Duration) -> UtcDate { + use chrono::{Datelike, TimeZone, Timelike}; + use log::warn; + + let dt = match chrono::Utc.timestamp_opt(duration.as_secs() as _, duration.subsec_nanos()) { + chrono::LocalResult::None => panic!("Invalid time"), + chrono::LocalResult::Single(s) => s, + chrono::LocalResult::Ambiguous(_, a) => { + warn!( + "Ambiguous time for epoch {:?}; returning latest timestamp: {a}", + duration + ); + a + } + }; + + UtcDate { + year: dt.year() as _, + month: dt.month() as _, + day: dt.day() as _, + hour: dt.hour() as _, + minute: dt.minute() as _, + second: dt.second() as _, + millis: (dt.nanosecond() / 1000) as _, + } +} diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index 116ad50c..4cdbf042 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -49,7 +49,11 @@ use matter::{ proto_ctx::ProtoCtx, session::{CaseDetails, CloneData, NocCatIds, SessionMgr, SessionMode}, }, - utils::{epoch::sys_epoch, rand::dummy_rand, writebuf::WriteBuf}, + utils::{ + epoch::{sys_epoch, sys_utc_calendar}, + rand::dummy_rand, + writebuf::WriteBuf, + }, Matter, }; use std::net::{Ipv4Addr, SocketAddr}; @@ -105,7 +109,7 @@ impl<'a> ImInput<'a> { pub type DmHandler<'a> = handler_chain_type!(OnOffCluster, EchoCluster, DescriptorCluster, EchoCluster | RootEndpointHandler<'a>); pub fn matter(mdns: &mut dyn Mdns) -> Matter<'_> { - Matter::new(&BASIC_INFO, mdns, sys_epoch, dummy_rand) + Matter::new(&BASIC_INFO, mdns, sys_epoch, dummy_rand, sys_utc_calendar) } /// An Interaction Model Engine to facilitate easy testing From 505fa39e8205e8b7218c0433294dd79881b8e8d1 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 24 Apr 2023 09:00:08 +0000 Subject: [PATCH 09/72] Create new secure channel sessions without async-channel --- matter/Cargo.toml | 1 - matter/src/error.rs | 15 ------- matter/src/secure_channel/case.rs | 16 ++++---- matter/src/secure_channel/core.rs | 24 +++++++---- matter/src/secure_channel/pake.rs | 30 ++++++++------ matter/src/transport/mgr.rs | 53 ++++++++++++------------ matter/src/transport/mod.rs | 1 - matter/src/transport/proto_ctx.rs | 2 +- matter/src/transport/queue.rs | 67 ------------------------------- 9 files changed, 70 insertions(+), 139 deletions(-) delete mode 100644 matter/src/transport/queue.rs diff --git a/matter/Cargo.toml b/matter/Cargo.toml index b6260838..53bea889 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -45,7 +45,6 @@ smol = "1.3.0" owning_ref = "0.4.1" safemem = "0.3.3" chrono = { version = "0.4.23", default-features = false, features = ["clock", "std"] } -async-channel = "1.8" # crypto openssl = { git = "https://github.com/sfackler/rust-openssl", optional = true } diff --git a/matter/src/error.rs b/matter/src/error.rs index 04d55b3b..e644a7aa 100644 --- a/matter/src/error.rs +++ b/matter/src/error.rs @@ -17,7 +17,6 @@ use core::{array::TryFromSliceError, fmt, str::Utf8Error}; -use async_channel::{SendError, TryRecvError}; use log::error; #[derive(Debug, PartialEq, Clone, Copy)] @@ -156,26 +155,12 @@ impl From for Error { } } -impl From> for Error { - fn from(e: SendError) -> Self { - error!("Error in channel send {}", e); - Self::Invalid - } -} - impl From for Error { fn from(_e: Utf8Error) -> Self { Self::Utf8Fail } } -impl From for Error { - fn from(e: TryRecvError) -> Self { - error!("Error in channel try_recv {}", e); - Self::Invalid - } -} - impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{:?}", self) diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index 4ffb6b23..e681ec92 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -30,7 +30,6 @@ use crate::{ transport::{ network::Address, proto_ctx::ProtoCtx, - queue::{Msg, WorkQ}, session::{CaseDetails, CloneData, NocCatIds, SessionMode}, }, utils::{epoch::UtcCalendar, rand::Rand, writebuf::WriteBuf}, @@ -83,7 +82,10 @@ impl<'a> Case<'a> { } } - pub fn casesigma3_handler(&mut self, ctx: &mut ProtoCtx) -> Result { + pub fn casesigma3_handler( + &mut self, + ctx: &mut ProtoCtx, + ) -> Result<(bool, Option), Error> { let mut case_session = ctx .exch_ctx .exch @@ -104,7 +106,7 @@ impl<'a> Case<'a> { None, )?; ctx.exch_ctx.exch.close(); - return Ok(true); + return Ok((true, None)); } // Safe to unwrap here let fabric = fabric.unwrap(); @@ -137,7 +139,7 @@ impl<'a> Case<'a> { error!("Certificate Chain doesn't match: {}", e); common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?; ctx.exch_ctx.exch.close(); - return Ok(true); + return Ok((true, None)); } if Case::validate_sigma3_sign( @@ -152,7 +154,7 @@ impl<'a> Case<'a> { error!("Sigma3 Signature doesn't match"); common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?; ctx.exch_ctx.exch.close(); - return Ok(true); + return Ok((true, None)); } // Only now do we add this message to the TT Hash @@ -167,13 +169,11 @@ impl<'a> Case<'a> { &case_session, &peer_catids, )?; - // Queue a transport mgr request to add a new session - WorkQ::get()?.sync_send(Msg::NewSession(clone_data))?; common::create_sc_status_report(ctx.tx, SCStatusCodes::SessionEstablishmentSuccess, None)?; ctx.exch_ctx.exch.clear_data(); ctx.exch_ctx.exch.close(); - Ok(true) + Ok((true, Some(clone_data))) } pub fn casesigma1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index be806a74..e69dca59 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -23,7 +23,7 @@ use crate::{ mdns::MdnsMgr, secure_channel::common::*, tlv, - transport::proto_ctx::ProtoCtx, + transport::{proto_ctx::ProtoCtx, session::CloneData}, utils::{epoch::UtcCalendar, rand::Rand}, }; use log::{error, info}; @@ -55,22 +55,30 @@ impl<'a> SecureChannel<'a> { } } - pub fn handle(&mut self, ctx: &mut ProtoCtx) -> Result { + pub fn handle(&mut self, ctx: &mut ProtoCtx) -> Result<(bool, Option), Error> { let proto_opcode: OpCode = num::FromPrimitive::from_u8(ctx.rx.get_proto_opcode()).ok_or(Error::Invalid)?; ctx.tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); info!("Received Opcode: {:?}", proto_opcode); info!("Received Data:"); tlv::print_tlv_list(ctx.rx.as_slice()); - let reply = match proto_opcode { - OpCode::MRPStandAloneAck => Ok(true), - OpCode::PBKDFParamRequest => self.pase.borrow_mut().pbkdfparamreq_handler(ctx), - OpCode::PASEPake1 => self.pase.borrow_mut().pasepake1_handler(ctx), + let (reply, clone_data) = match proto_opcode { + OpCode::MRPStandAloneAck => Ok((true, None)), + OpCode::PBKDFParamRequest => self + .pase + .borrow_mut() + .pbkdfparamreq_handler(ctx) + .map(|reply| (reply, None)), + OpCode::PASEPake1 => self + .pase + .borrow_mut() + .pasepake1_handler(ctx) + .map(|reply| (reply, None)), OpCode::PASEPake3 => self .pase .borrow_mut() .pasepake3_handler(ctx, &mut self.mdns.borrow_mut()), - OpCode::CASESigma1 => self.case.casesigma1_handler(ctx), + OpCode::CASESigma1 => self.case.casesigma1_handler(ctx).map(|reply| (reply, None)), OpCode::CASESigma3 => self.case.casesigma3_handler(ctx), _ => { error!("OpCode Not Handled: {:?}", proto_opcode); @@ -83,6 +91,6 @@ impl<'a> SecureChannel<'a> { tlv::print_tlv_list(ctx.tx.as_mut_slice()); } - Ok(reply) + Ok((reply, clone_data)) } } diff --git a/matter/src/secure_channel/pake.rs b/matter/src/secure_channel/pake.rs index ce05fb65..1901686c 100644 --- a/matter/src/secure_channel/pake.rs +++ b/matter/src/secure_channel/pake.rs @@ -31,7 +31,6 @@ use crate::{ exchange::ExchangeCtx, network::Address, proto_ctx::ProtoCtx, - queue::{Msg, WorkQ}, session::{CloneData, SessionMode}, }, utils::{epoch::Epoch, rand::Rand}, @@ -101,15 +100,18 @@ impl PaseMgr { /// If the PASE Session is enabled, execute the closure, /// if not enabled, generate SC Status Report - fn if_enabled(&mut self, ctx: &mut ProtoCtx, f: F) -> Result<(), Error> + fn if_enabled(&mut self, ctx: &mut ProtoCtx, f: F) -> Result, Error> where - F: FnOnce(&mut Pake, &mut ProtoCtx) -> Result<(), Error>, + F: FnOnce(&mut Pake, &mut ProtoCtx) -> Result, { if let PaseMgrState::Enabled(pake, _, _) = &mut self.state { - f(pake, ctx) + let data = f(pake, ctx)?; + + Ok(Some(data)) } else { error!("PASE Not enabled"); - create_sc_status_report(ctx.tx, SCStatusCodes::InvalidParameter, None) + create_sc_status_report(ctx.tx, SCStatusCodes::InvalidParameter, None)?; + Ok(None) } } @@ -129,10 +131,10 @@ impl PaseMgr { &mut self, ctx: &mut ProtoCtx, mdns: &mut MdnsMgr, - ) -> Result { - self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?; + ) -> Result<(bool, Option), Error> { + let clone_data = self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?; self.disable_pase_session(mdns)?; - Ok(true) + Ok((true, clone_data.flatten())) } } @@ -230,13 +232,13 @@ impl Pake { } #[allow(non_snake_case)] - pub fn handle_pasepake3(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { + pub fn handle_pasepake3(&mut self, ctx: &mut ProtoCtx) -> Result, Error> { let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; let cA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?; let (status_code, ke) = sd.spake2p.handle_cA(cA); - if status_code == SCStatusCodes::SessionEstablishmentSuccess { + let clone_data = if status_code == SCStatusCodes::SessionEstablishmentSuccess { // Get the keys let ke = ke.ok_or(Error::Invalid)?; let mut session_keys: [u8; 48] = [0; 48]; @@ -262,12 +264,14 @@ impl Pake { .copy_from_slice(&session_keys[32..48]); // Queue a transport mgr request to add a new session - WorkQ::get()?.sync_send(Msg::NewSession(clone_data))?; - } + Some(clone_data) + } else { + None + }; create_sc_status_report(ctx.tx, status_code, None)?; ctx.exch_ctx.exch.close(); - Ok(()) + Ok(clone_data) } #[allow(non_snake_case)] diff --git a/matter/src/transport/mgr.rs b/matter/src/transport/mgr.rs index 2ada9429..1d68f344 100644 --- a/matter/src/transport/mgr.rs +++ b/matter/src/transport/mgr.rs @@ -33,12 +33,14 @@ use crate::utils::epoch::{Epoch, UtcCalendar}; use crate::utils::rand::Rand; use super::proto_ctx::ProtoCtx; +use super::session::CloneData; -#[derive(Copy, Clone, Eq, PartialEq)] enum RecvState { New, OpenExchange, + AddSession(CloneData), EvictSession, + EvictSession2(CloneData), Ack, } @@ -69,7 +71,7 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { fn maybe_next_action(&mut self) -> Result>>, Error> { self.mgr.exch_mgr.purge(); - match self.state { + match core::mem::replace(&mut self.state, RecvState::New) { RecvState::New => { self.mgr.exch_mgr.get_sess_mgr().decode(self.rx)?; self.state = RecvState::OpenExchange; @@ -80,13 +82,18 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { if self.rx.get_proto_id() == PROTO_ID_SECURE_CHANNEL { let mut proto_ctx = ProtoCtx::new(exch_ctx, self.rx, self.tx); - if self.mgr.secure_channel.handle(&mut proto_ctx)? { - proto_ctx.send()?; + let (reply, clone_data) = self.mgr.secure_channel.handle(&mut proto_ctx)?; + if let Some(clone_data) = clone_data { + self.state = RecvState::AddSession(clone_data); + } else { self.state = RecvState::Ack; + } + + if reply { + proto_ctx.send()?; Ok(Some(Some(RecvAction::Send(self.tx.as_slice())))) } else { - self.state = RecvState::Ack; Ok(None) } } else { @@ -106,11 +113,27 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { } Err(err) => Err(err), }, + RecvState::AddSession(clone_data) => match self.mgr.exch_mgr.add_session(&clone_data) { + Ok(_) => { + self.state = RecvState::Ack; + Ok(None) + } + Err(Error::NoSpace) => { + self.state = RecvState::EvictSession2(clone_data); + Ok(None) + } + Err(err) => Err(err), + }, RecvState::EvictSession => { self.mgr.exch_mgr.evict_session(self.tx)?; self.state = RecvState::OpenExchange; Ok(Some(Some(RecvAction::Send(self.tx.as_slice())))) } + RecvState::EvictSession2(clone_data) => { + self.mgr.exch_mgr.evict_session(self.tx)?; + self.state = RecvState::AddSession(clone_data); + Ok(Some(Some(RecvAction::Send(self.tx.as_slice())))) + } RecvState::Ack => { if let Some(exch_id) = self.mgr.exch_mgr.pending_ack() { info!("Sending MRP Standalone ACK for exch {}", exch_id); @@ -127,7 +150,6 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { } } -#[derive(Copy, Clone, Eq, PartialEq)] enum NotifyState {} pub enum NotifyAction<'r, 'p> { @@ -212,23 +234,4 @@ impl<'a> TransportMgr<'a> { pub fn notify(&mut self, _tx: &mut Packet) -> Result { Ok(false) } - - // async fn handle_queue_msgs(&mut self) -> Result<(), Error> { - // if let Ok(msg) = self.rx_q.try_recv() { - // match msg { - // Msg::NewSession(clone_data) => { - // // If a new session was created, add it - // let _ = self - // .exch_mgr - // .add_session(&clone_data) - // .await - // .map_err(|e| error!("Error adding new session {:?}", e)); - // } - // _ => { - // error!("Queue Message Type not yet handled {:?}", msg); - // } - // } - // } - // Ok(()) - // } } diff --git a/matter/src/transport/mod.rs b/matter/src/transport/mod.rs index 43acccd7..1a81c75c 100644 --- a/matter/src/transport/mod.rs +++ b/matter/src/transport/mod.rs @@ -24,6 +24,5 @@ pub mod packet; pub mod plain_hdr; pub mod proto_ctx; pub mod proto_hdr; -pub mod queue; pub mod session; pub mod udp; diff --git a/matter/src/transport/proto_ctx.rs b/matter/src/transport/proto_ctx.rs index 747a1e6a..c4bf7f38 100644 --- a/matter/src/transport/proto_ctx.rs +++ b/matter/src/transport/proto_ctx.rs @@ -38,6 +38,6 @@ impl<'a, 'b> ProtoCtx<'a, 'b> { pub fn send(&mut self) -> Result<&[u8], Error> { self.exch_ctx.exch.send(self.tx, &mut self.exch_ctx.sess)?; - Ok(self.tx.as_mut_slice()) + Ok(self.tx.as_slice()) } } diff --git a/matter/src/transport/queue.rs b/matter/src/transport/queue.rs deleted file mode 100644 index b0c0f37f..00000000 --- a/matter/src/transport/queue.rs +++ /dev/null @@ -1,67 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use std::sync::Once; - -use async_channel::{bounded, Receiver, Sender}; - -use crate::error::Error; - -use super::session::CloneData; - -#[derive(Debug)] -pub enum Msg { - Tx(), - Rx(), - NewSession(CloneData), -} - -#[derive(Clone)] -pub struct WorkQ { - tx: Sender, -} - -static mut G_WQ: Option = None; -static INIT: Once = Once::new(); - -impl WorkQ { - pub fn init() -> Result, Error> { - let (tx, rx) = bounded::(3); - WorkQ::configure(tx); - Ok(rx) - } - - fn configure(tx: Sender) { - unsafe { - INIT.call_once(|| { - G_WQ = Some(WorkQ { tx }); - }); - } - } - - pub fn get() -> Result { - unsafe { G_WQ.as_ref().cloned().ok_or(Error::Invalid) } - } - - pub fn sync_send(&self, msg: Msg) -> Result<(), Error> { - smol::block_on(self.send(msg)) - } - - pub async fn send(&self, msg: Msg) -> Result<(), Error> { - self.tx.send(msg).await.map_err(|e| e.into()) - } -} From 688d7ea8d52cd401c366724012fb698dab0e830c Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 24 Apr 2023 09:26:04 +0000 Subject: [PATCH 10/72] More ergonomic api when STD is available --- matter/src/core.rs | 8 ++++++++ matter/src/transport/mgr.rs | 4 ++-- matter/src/utils/rand.rs | 7 +++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/matter/src/core.rs b/matter/src/core.rs index 0b555cc5..24e90cf0 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -54,6 +54,14 @@ pub struct Matter<'a> { } impl<'a> Matter<'a> { + #[cfg(feature = "std")] + pub fn new_default(dev_det: &'a BasicInfoConfig, mdns: &'a mut dyn Mdns) -> Self { + use crate::utils::epoch::{sys_epoch, sys_utc_calendar}; + use crate::utils::rand::sys_rand; + + Self::new(dev_det, mdns, sys_epoch, sys_rand, sys_utc_calendar) + } + /// Creates a new Matter object /// /// # Parameters diff --git a/matter/src/transport/mgr.rs b/matter/src/transport/mgr.rs index 1d68f344..16925b2a 100644 --- a/matter/src/transport/mgr.rs +++ b/matter/src/transport/mgr.rs @@ -191,18 +191,18 @@ impl<'a> TransportMgr<'a> { pub fn new< T: Borrow> + Borrow> + + Borrow>> + Borrow + Borrow + Borrow, >( matter: &'a T, - mdns_mgr: &'a RefCell>, ) -> Self { Self::wrap( SecureChannel::new( matter.borrow(), matter.borrow(), - mdns_mgr, + matter.borrow(), *matter.borrow(), *matter.borrow(), ), diff --git a/matter/src/utils/rand.rs b/matter/src/utils/rand.rs index 3cd698ca..59a89727 100644 --- a/matter/src/utils/rand.rs +++ b/matter/src/utils/rand.rs @@ -1,3 +1,10 @@ pub type Rand = fn(&mut [u8]); pub fn dummy_rand(_buf: &mut [u8]) {} + +#[cfg(feature = "std")] +pub fn sys_rand(buf: &mut [u8]) { + use rand::{thread_rng, RngCore}; + + thread_rng().fill_bytes(buf); +} From d9349120076a67a1bda3bea79751c13a715d9898 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 24 Apr 2023 09:44:44 +0000 Subject: [PATCH 11/72] Fix compilation error since the introduction of UtcCalendar --- tools/tlv_tool/Cargo.toml | 2 +- tools/tlv_tool/src/main.rs | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tools/tlv_tool/Cargo.toml b/tools/tlv_tool/Cargo.toml index 2aaabe80..f8c1e232 100644 --- a/tools/tlv_tool/Cargo.toml +++ b/tools/tlv_tool/Cargo.toml @@ -6,7 +6,7 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -matter-iot= { path = "../../matter" } +matter-iot = { path = "../../matter" } log = {version = "0.4.14", features = ["max_level_trace", "release_max_level_warn"]} simple_logger = "1.16.0" clap = "2.34" diff --git a/tools/tlv_tool/src/main.rs b/tools/tlv_tool/src/main.rs index 821e0d00..cc08cb47 100644 --- a/tools/tlv_tool/src/main.rs +++ b/tools/tlv_tool/src/main.rs @@ -15,13 +15,12 @@ * limitations under the License. */ -extern crate clap; use clap::{App, Arg}; use matter::cert; use matter::tlv; +use matter::utils::epoch::sys_utc_calendar; use simple_logger::SimpleLogger; use std::process; -use std::u8; fn main() { SimpleLogger::new() @@ -96,7 +95,7 @@ fn main() { } else if m.is_present("as-asn1") { let mut asn1_cert = [0_u8; 1024]; let cert = cert::Cert::new(&tlv_list[..index]).unwrap(); - let len = cert.as_asn1(&mut asn1_cert).unwrap(); + let len = cert.as_asn1(&mut asn1_cert, sys_utc_calendar).unwrap(); println!("{:02x?}", &asn1_cert[..len]); } else { tlv::print_tlv_list(&tlv_list[..index]); From d558c73f8d90100b1c8a7443a30b18814a6474d4 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 24 Apr 2023 09:45:03 +0000 Subject: [PATCH 12/72] Cleanup the dependencies as much as possible --- matter/Cargo.toml | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 53bea889..d649645a 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -16,7 +16,7 @@ path = "src/lib.rs" [features] default = ["std", "crypto_mbedtls"] -std = ["alloc"] +std = ["alloc", "env_logger", "chrono", "rand", "qrcode", "smol"] alloc = [] nightly = [] crypto_openssl = ["openssl", "foreign-types", "hmac", "sha2"] @@ -29,27 +29,28 @@ boxslab = { path = "../boxslab" } matter_macro_derive = { path = "../matter_macro_derive" } bitflags = "1.3" byteorder = "1.4.3" -heapless = { version = "0.7.16", features = ["x86-sync-pool"] } -generic-array = "0.14.6" +heapless = "0.7.16" num = "0.4" num-derive = "0.3.3" num-traits = "0.2.15" strum = { version = "0.24", features = ["derive"], default-features = false, no-default-feature = true } log = { version = "0.4.17", features = ["max_level_debug", "release_max_level_debug"] } -env_logger = { version = "0.10.0", default-features = false, features = [] } -rand = "0.8.5" -esp-idf-sys = { version = "0.32", optional = true } subtle = "2.4.1" -colored = "2.0.0" -smol = "1.3.0" -owning_ref = "0.4.1" safemem = "0.3.3" -chrono = { version = "0.4.23", default-features = false, features = ["clock", "std"] } +colored = "2.0.0" # TODO: Requires STD + +# STD-only dependencies +env_logger = { version = "0.10.0", default-features = false, optional = true } +chrono = { version = "0.4.23", optional = true, default-features = false, features = ["clock", "std"] } +rand = { version = "0.8.5", optional = true } +smol = { version = "1.3.0", optional = true} +qrcode = { version = "0.12", default-features = false, optional = true } # Print QR code # crypto openssl = { git = "https://github.com/sfackler/rust-openssl", optional = true } -foreign-types = { version = "0.3.2", optional = true } mbedtls = { version = "0.9", optional = true } +esp-idf-sys = { version = "0.32", optional = true } +foreign-types = { version = "0.3.2", optional = true } sha2 = { version = "0.10", default-features = false, optional = true } hmac = { version = "0.12", optional = true } pbkdf2 = { version = "0.12", optional = true } @@ -59,21 +60,17 @@ ccm = { version = "0.5", default-features = false, features = ["alloc"], optiona p256 = { version = "0.13.0", default-features = false, features = ["arithmetic", "ecdh", "ecdsa"], optional = true } elliptic-curve = { version = "0.13.2", optional = true } crypto-bigint = { version = "0.4", default-features = false, optional = true } -# Note: requires std +# TODO: requires STD x509-cert = { version = "0.2.0", default-features = false, features = ["pem", "std"], optional = true } # to compute the check digit verhoeff = "1" -# print QR code -qrcode = { version = "0.12", default-features = false } - [target.'cfg(target_os = "macos")'.dependencies] astro-dnssd = "0.3" # MDNS support [target.'cfg(target_os = "linux")'.dependencies] -lazy_static = "1.4.0" libmdns = { version = "0.7.4" } [[example]] From faf5af3e1f51af740ff1e8346468d807aeb2ae6a Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 24 Apr 2023 09:48:39 +0000 Subject: [PATCH 13/72] no_std printing of QR code (kind of...) --- matter/src/pairing/mod.rs | 1 - matter/src/pairing/qr.rs | 30 +++++++++++++++++++++--------- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/matter/src/pairing/mod.rs b/matter/src/pairing/mod.rs index ee5aaffa..96f3105e 100644 --- a/matter/src/pairing/mod.rs +++ b/matter/src/pairing/mod.rs @@ -22,7 +22,6 @@ pub mod qr; pub mod vendor_identifiers; use log::info; -use qrcode::{render::unicode, QrCode, Version}; use verhoeff::Verhoeff; use crate::{ diff --git a/matter/src/pairing/qr.rs b/matter/src/pairing/qr.rs index 70b668d7..33f7b8a4 100644 --- a/matter/src/pairing/qr.rs +++ b/matter/src/pairing/qr.rs @@ -315,15 +315,27 @@ fn estimate_struct_overhead(first_field_size: usize) -> usize { } pub(super) fn print_qr_code(qr_data: &str) { - let needed_version = compute_qr_version(qr_data); - let code = - QrCode::with_version(qr_data, Version::Normal(needed_version), qrcode::EcLevel::M).unwrap(); - let image = code - .render::() - .dark_color(unicode::Dense1x2::Light) - .light_color(unicode::Dense1x2::Dark) - .build(); - info!("\n{}", image); + #[cfg(not(feature = "std"))] + { + info!("\n QR CODE DATA: {}", qr_data); + } + + #[cfg(feature = "std")] + { + use qrcode::{render::unicode, QrCode, Version}; + + let needed_version = compute_qr_version(qr_data); + let code = + QrCode::with_version(qr_data, Version::Normal(needed_version), qrcode::EcLevel::M) + .unwrap(); + let image = code + .render::() + .dark_color(unicode::Dense1x2::Light) + .light_color(unicode::Dense1x2::Dark) + .build(); + + info!("\n{}", image); + } } fn compute_qr_version(qr_data: &str) -> i16 { From 2ea31432d566dcb1fc3b1a338300720f8bc590d7 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 24 Apr 2023 11:34:32 +0000 Subject: [PATCH 14/72] On-off example now buildable --- examples/onoff_light/src/lib.rs | 2 +- examples/onoff_light/src/main.rs | 143 +++++++++++++++++++------ matter/Cargo.toml | 2 +- matter/src/core.rs | 2 +- matter/src/data_model/root_endpoint.rs | 10 +- matter/src/interaction_model/core.rs | 32 ++---- matter/src/transport/exchange.rs | 32 ++++-- matter/src/transport/mgr.rs | 33 ++++-- matter/src/transport/proto_ctx.rs | 7 +- matter/src/transport/udp.rs | 2 +- 10 files changed, 178 insertions(+), 87 deletions(-) diff --git a/examples/onoff_light/src/lib.rs b/examples/onoff_light/src/lib.rs index 16264d04..43ca1b11 100644 --- a/examples/onoff_light/src/lib.rs +++ b/examples/onoff_light/src/lib.rs @@ -15,4 +15,4 @@ * limitations under the License. */ -// TODO pub mod dev_att; +pub mod dev_att; diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index b2bc4484..08dece3e 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -15,41 +15,114 @@ * limitations under the License. */ -// TODO -// mod dev_att; -// use matter::core::{self, CommissioningData}; -// use matter::data_model::cluster_basic_information::BasicInfoConfig; -// use matter::data_model::device_types::device_type_add_on_off_light; -// use matter::secure_channel::spake2p::VerifierData; +use std::borrow::Borrow; + +use matter::core::{CommissioningData, Matter}; +use matter::data_model::cluster_basic_information::BasicInfoConfig; +use matter::data_model::cluster_on_off; +use matter::data_model::core::DataModel; +use matter::data_model::device_types::DEV_TYPE_ON_OFF_LIGHT; +use matter::data_model::objects::*; +use matter::data_model::root_endpoint; +use matter::data_model::sdm::dev_att::DevAttDataFetcher; +use matter::interaction_model::core::InteractionModel; +use matter::secure_channel::spake2p::VerifierData; +use matter::transport::{ + mgr::RecvAction, mgr::TransportMgr, packet::Packet, packet::MAX_RX_BUF_SIZE, + packet::MAX_TX_BUF_SIZE, udp::UdpListener, +}; + +mod dev_att; fn main() { - // env_logger::init(); - // let comm_data = CommissioningData { - // // TODO: Hard-coded for now - // verifier: VerifierData::new_with_pw(123456), - // discriminator: 250, - // }; - - // // vid/pid should match those in the DAC - // let dev_info = BasicInfoConfig { - // vid: 0xFFF1, - // pid: 0x8000, - // hw_ver: 2, - // sw_ver: 1, - // sw_ver_str: "1".to_string(), - // serial_no: "aabbccdd".to_string(), - // device_name: "OnOff Light".to_string(), - // }; - // let dev_att = Box::new(dev_att::HardCodedDevAtt::new()); - - // let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap(); - // let dm = matter.get_data_model(); - // { - // let mut node = dm.node.write().unwrap(); - // let endpoint = device_type_add_on_off_light(&mut node).unwrap(); - // println!("Added OnOff Light Device type at endpoint id: {}", endpoint); - // println!("Data Model now is: {}", node); - // } - - // matter.start_daemon().unwrap(); + env_logger::init(); + + // vid/pid should match those in the DAC + let dev_info = BasicInfoConfig { + vid: 0xFFF1, + pid: 0x8000, + hw_ver: 2, + sw_ver: 1, + sw_ver_str: "1", + serial_no: "aabbccdd", + device_name: "OnOff Light", + }; + + let mut mdns = matter::sys::LinuxMdns::new().unwrap(); + + let matter = Matter::new_default(&dev_info, &mut mdns); + + let dev_att = dev_att::HardCodedDevAtt::new(); + + matter + .start::<4096>( + CommissioningData { + // TODO: Hard-coded for now + verifier: VerifierData::new_with_pw(123456, *matter.borrow()), + discriminator: 250, + }, + &mut [0; 4096], + ) + .unwrap(); + + let matter = &matter; + let dev_att = &dev_att; + + let mut transport = TransportMgr::new(matter); + + smol::block_on(async move { + let udp = UdpListener::new().await.unwrap(); + + loop { + let mut rx_buf = [0; MAX_RX_BUF_SIZE]; + let mut tx_buf = [0; MAX_TX_BUF_SIZE]; + + let (len, addr) = udp.recv(&mut rx_buf).await.unwrap(); + + let mut rx = Packet::new_rx(&mut rx_buf[..len]); + let mut tx = Packet::new_tx(&mut tx_buf); + + let mut completion = transport.recv(addr, &mut rx, &mut tx); + + while let Some(action) = completion.next_action().unwrap() { + match action { + RecvAction::Send(addr, buf) => { + udp.send(addr, buf).await.unwrap(); + } + RecvAction::Interact(mut ctx) => { + let node = Node { + id: 0, + endpoints: &[ + root_endpoint::endpoint(0), + Endpoint { + id: 1, + device_type: DEV_TYPE_ON_OFF_LIGHT, + clusters: &[cluster_on_off::CLUSTER], + }, + ], + }; + + let mut handler = handler(matter, dev_att); + + let mut im = + InteractionModel(DataModel::new(matter.borrow(), &node, &mut handler)); + + if im.handle(&mut ctx).unwrap() { + if let Some(addr) = ctx.send().unwrap() { + udp.send(addr, ctx.tx.as_slice()).await.unwrap(); + } + } + } + } + } + } + }); +} + +fn handler<'a>(matter: &'a Matter<'a>, dev_att: &'a dyn DevAttDataFetcher) -> impl Handler + 'a { + root_endpoint::handler(0, dev_att, matter).chain( + 1, + cluster_on_off::ID, + cluster_on_off::OnOffCluster::new(*matter.borrow()), + ) } diff --git a/matter/Cargo.toml b/matter/Cargo.toml index d649645a..db10e775 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -15,7 +15,7 @@ name = "matter" path = "src/lib.rs" [features] -default = ["std", "crypto_mbedtls"] +default = ["std", "crypto_mbedtls", "nightly"] std = ["alloc", "env_logger", "chrono", "rand", "qrcode", "smol"] alloc = [] nightly = [] diff --git a/matter/src/core.rs b/matter/src/core.rs index 24e90cf0..41c53071 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -99,7 +99,7 @@ impl<'a> Matter<'a> { } pub fn start( - &mut self, + &self, dev_comm: CommissioningData, buf: &mut [u8], ) -> Result<(), Error> { diff --git a/matter/src/data_model/root_endpoint.rs b/matter/src/data_model/root_endpoint.rs index ebcbb140..341b3de1 100644 --- a/matter/src/data_model/root_endpoint.rs +++ b/matter/src/data_model/root_endpoint.rs @@ -12,7 +12,7 @@ use crate::{ use super::{ cluster_basic_information::{self, BasicInfoCluster, BasicInfoConfig}, - objects::{Cluster, EmptyHandler}, + objects::{Cluster, EmptyHandler, Endpoint, EndptId}, sdm::{ admin_commissioning::{self, AdminCommCluster}, dev_att::DevAttDataFetcher, @@ -47,6 +47,14 @@ pub const CLUSTERS: [Cluster<'static>; 7] = [ access_control::CLUSTER, ]; +pub fn endpoint(id: EndptId) -> Endpoint<'static> { + Endpoint { + id, + device_type: super::device_types::DEV_TYPE_ROOT_NODE, + clusters: &CLUSTERS, + } +} + pub fn handler<'a>( endpoint_id: u16, dev_att: &'a dyn DevAttDataFetcher, diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index 162d64c6..fd9e130f 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -710,14 +710,14 @@ impl ResumeSubscribeReq { } pub trait InteractionHandler { - fn handle<'a>(&mut self, ctx: &'a mut ProtoCtx) -> Result, Error>; + fn handle(&mut self, ctx: &mut ProtoCtx) -> Result; } impl InteractionHandler for &mut T where T: InteractionHandler, { - fn handle<'a>(&mut self, ctx: &'a mut ProtoCtx) -> Result, Error> { + fn handle(&mut self, ctx: &mut ProtoCtx) -> Result { (**self).handle(ctx) } } @@ -728,7 +728,7 @@ impl InteractionModel where T: DataHandler, { - pub fn handle<'a>(&mut self, ctx: &'a mut ProtoCtx) -> Result, Error> { + pub fn handle(&mut self, ctx: &mut ProtoCtx) -> Result { let mut transaction = Transaction::new(&mut ctx.exch_ctx); let reply = @@ -738,7 +738,7 @@ where true }; - Ok(reply.then_some(ctx.tx.as_slice())) + Ok(reply) } } @@ -747,10 +747,7 @@ impl InteractionModel where T: crate::data_model::core::asynch::AsyncDataHandler, { - pub async fn handle_async<'a>( - &mut self, - ctx: &'a mut ProtoCtx<'_, '_>, - ) -> Result, Error> { + pub async fn handle_async<'a>(&mut self, ctx: &mut ProtoCtx<'_, '_>) -> Result { let mut transaction = Transaction::new(&mut ctx.exch_ctx); let reply = @@ -760,7 +757,7 @@ where true }; - Ok(reply.then_some(ctx.tx.as_slice())) + Ok(reply) } } @@ -768,7 +765,7 @@ impl InteractionHandler for InteractionModel where T: DataHandler, { - fn handle<'a>(&mut self, ctx: &'a mut ProtoCtx) -> Result, Error> { + fn handle(&mut self, ctx: &mut ProtoCtx) -> Result { InteractionModel::handle(self, ctx) } } @@ -782,20 +779,14 @@ pub mod asynch { use super::InteractionModel; pub trait AsyncInteractionHandler { - async fn handle<'a>( - &mut self, - ctx: &'a mut ProtoCtx<'_, '_>, - ) -> Result, Error>; + async fn handle(&mut self, ctx: &mut ProtoCtx<'_, '_>) -> Result; } impl AsyncInteractionHandler for &mut T where T: AsyncInteractionHandler, { - async fn handle<'a>( - &mut self, - ctx: &'a mut ProtoCtx<'_, '_>, - ) -> Result, Error> { + async fn handle(&mut self, ctx: &mut ProtoCtx<'_, '_>) -> Result { (**self).handle(ctx).await } } @@ -804,10 +795,7 @@ pub mod asynch { where T: AsyncDataHandler, { - async fn handle<'a>( - &mut self, - ctx: &'a mut ProtoCtx<'_, '_>, - ) -> Result, Error> { + async fn handle(&mut self, ctx: &mut ProtoCtx<'_, '_>) -> Result { InteractionModel::handle_async(self, ctx).await } } diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index 053bf793..8e3df105 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -31,7 +31,10 @@ use crate::utils::rand::Rand; use heapless::LinearMap; use super::session::CloneData; -use super::{mrp::ReliableMessage, packet::Packet, session::SessionHandle, session::SessionMgr}; +use super::{ + mrp::ReliableMessage, network::Address, packet::Packet, session::SessionHandle, + session::SessionMgr, +}; pub struct ExchangeCtx<'a> { pub exch: &'a mut Exchange, @@ -40,7 +43,7 @@ pub struct ExchangeCtx<'a> { } impl<'a> ExchangeCtx<'a> { - pub fn send(&mut self, tx: &mut Packet) -> Result<(), Error> { + pub fn send(&mut self, tx: &mut Packet) -> Result, Error> { self.exch.send(tx, &mut self.sess) } } @@ -198,10 +201,10 @@ impl Exchange { &mut self, tx: &mut Packet, session: &mut SessionHandle, - ) -> Result<(), Error> { + ) -> Result, Error> { if self.state == State::Terminate { info!("Skipping tx for terminated exchange {}", self.id); - return Ok(()); + return Ok(None); } trace!("payload: {:x?}", tx.as_mut_slice()); @@ -219,7 +222,9 @@ impl Exchange { session.pre_send(tx)?; self.mrp.pre_send(tx)?; - session.send(tx) + session.send(tx)?; + + Ok(Some(session.get_peer_addr())) } } @@ -354,11 +359,13 @@ impl ExchangeMgr { } } - pub fn send(&mut self, exch_id: u16, tx: &mut Packet) -> Result<(), Error> { + pub fn send(&mut self, exch_id: u16, tx: &mut Packet) -> Result { let exchange = ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id).ok_or(Error::NoExchange)?; let mut session = self.sess_mgr.get_session_handle(exchange.sess_idx); - exchange.send(tx, &mut session) + exchange.send(tx, &mut session)?; + + Ok(session.get_peer_addr()) } pub fn purge(&mut self) { @@ -381,7 +388,7 @@ impl ExchangeMgr { .map(|(exch_id, _)| *exch_id) } - pub fn evict_session(&mut self, tx: &mut Packet) -> Result { + pub fn evict_session(&mut self, tx: &mut Packet) -> Result, Error> { if let Some(index) = self.sess_mgr.get_session_for_eviction() { info!("Sessions full, vacating session with index: {}", index); // If we enter here, we have an LRU session that needs to be reclaimed @@ -423,11 +430,14 @@ impl ExchangeMgr { // Remove from exchange list self.exchanges.remove(&exch_id); } + + let addr = session.get_peer_addr(); + self.sess_mgr.remove(index); - Ok(true) + Ok(Some(addr)) } else { - Ok(false) + Ok(None) } } @@ -561,7 +571,7 @@ mod tests { let mut buf = [0; MAX_TX_BUF_SIZE]; let tx = &mut Packet::new_tx(&mut buf); let evicted = mgr.evict_session(tx).unwrap(); - assert!(evicted); + assert!(evicted.is_some()); let session = mgr .add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id)) diff --git a/matter/src/transport/mgr.rs b/matter/src/transport/mgr.rs index 16925b2a..c4bc2e7b 100644 --- a/matter/src/transport/mgr.rs +++ b/matter/src/transport/mgr.rs @@ -32,6 +32,7 @@ use crate::transport::{exchange, packet::Packet}; use crate::utils::epoch::{Epoch, UtcCalendar}; use crate::utils::rand::Rand; +use super::network::Address; use super::proto_ctx::ProtoCtx; use super::session::CloneData; @@ -45,12 +46,13 @@ enum RecvState { } pub enum RecvAction<'r, 'p> { - Send(&'r [u8]), + Send(Address, &'r [u8]), Interact(ProtoCtx<'r, 'p>), } pub struct RecvCompletion<'r, 'a, 'p> { mgr: &'r mut TransportMgr<'a>, + addr: Address, // TODO: Not used yet rx: &'r mut Packet<'p>, tx: &'r mut Packet<'p>, state: RecvState, @@ -90,9 +92,10 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { self.state = RecvState::Ack; } - if reply { - proto_ctx.send()?; - Ok(Some(Some(RecvAction::Send(self.tx.as_slice())))) + let addr = if reply { proto_ctx.send()? } else { None }; + + if let Some(addr) = addr { + Ok(Some(Some(RecvAction::Send(addr, self.tx.as_slice())))) } else { Ok(None) } @@ -125,14 +128,22 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { Err(err) => Err(err), }, RecvState::EvictSession => { - self.mgr.exch_mgr.evict_session(self.tx)?; + let addr = self.mgr.exch_mgr.evict_session(self.tx)?; self.state = RecvState::OpenExchange; - Ok(Some(Some(RecvAction::Send(self.tx.as_slice())))) + if let Some(addr) = addr { + Ok(Some(Some(RecvAction::Send(addr, self.tx.as_slice())))) + } else { + Ok(None) + } } RecvState::EvictSession2(clone_data) => { - self.mgr.exch_mgr.evict_session(self.tx)?; + let addr = self.mgr.exch_mgr.evict_session(self.tx)?; self.state = RecvState::AddSession(clone_data); - Ok(Some(Some(RecvAction::Send(self.tx.as_slice())))) + if let Some(addr) = addr { + Ok(Some(Some(RecvAction::Send(addr, self.tx.as_slice())))) + } else { + Ok(None) + } } RecvState::Ack => { if let Some(exch_id) = self.mgr.exch_mgr.pending_ack() { @@ -140,8 +151,8 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { ReliableMessage::prepare_ack(exch_id, self.tx); - self.mgr.exch_mgr.send(exch_id, self.tx)?; - Ok(Some(Some(RecvAction::Send(self.tx.as_slice())))) + let addr = self.mgr.exch_mgr.send(exch_id, self.tx)?; + Ok(Some(Some(RecvAction::Send(addr, self.tx.as_slice())))) } else { Ok(Some(None)) } @@ -220,11 +231,13 @@ impl<'a> TransportMgr<'a> { pub fn recv<'r, 'p>( &'r mut self, + addr: Address, rx: &'r mut Packet<'p>, tx: &'r mut Packet<'p>, ) -> RecvCompletion<'r, 'a, 'p> { RecvCompletion { mgr: self, + addr, rx, tx, state: RecvState::New, diff --git a/matter/src/transport/proto_ctx.rs b/matter/src/transport/proto_ctx.rs index c4bf7f38..c7b95db3 100644 --- a/matter/src/transport/proto_ctx.rs +++ b/matter/src/transport/proto_ctx.rs @@ -18,6 +18,7 @@ use crate::error::Error; use super::exchange::ExchangeCtx; +use super::network::Address; use super::packet::Packet; /// This is the context in which a receive packet is being processed @@ -35,9 +36,7 @@ impl<'a, 'b> ProtoCtx<'a, 'b> { Self { exch_ctx, rx, tx } } - pub fn send(&mut self) -> Result<&[u8], Error> { - self.exch_ctx.exch.send(self.tx, &mut self.exch_ctx.sess)?; - - Ok(self.tx.as_slice()) + pub fn send(&mut self) -> Result, Error> { + self.exch_ctx.exch.send(self.tx, &mut self.exch_ctx.sess) } } diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index 6f7a2651..7a82fb33 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -48,7 +48,7 @@ impl UdpListener { Ok((size, Address::Udp(addr))) } - pub async fn send(&self, out_buf: &[u8], addr: Address) -> Result { + pub async fn send(&self, addr: Address, out_buf: &[u8]) -> Result { match addr { Address::Udp(addr) => self.socket.send_to(out_buf, addr).await.map_err(|e| { info!("Error on the network: {:?}", e); From eb3c9cdfb1f78b2d35061bb9938744809ddf5085 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 24 Apr 2023 19:27:17 +0000 Subject: [PATCH 15/72] Cleanup a bit the mDns story --- examples/onoff_light/src/main.rs | 3 +- matter/Cargo.toml | 15 +- matter/src/mdns.rs | 403 +++++++++++++++++++++++++-- matter/src/pairing/qr.rs | 5 +- matter/src/secure_channel/spake2p.rs | 7 +- matter/src/sys/mod.rs | 12 - matter/src/sys/sys_linux.rs | 116 -------- matter/src/sys/sys_macos.rs | 109 -------- matter/src/transport/udp.rs | 41 ++- 9 files changed, 428 insertions(+), 283 deletions(-) delete mode 100644 matter/src/sys/sys_linux.rs delete mode 100644 matter/src/sys/sys_macos.rs diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 08dece3e..d4801887 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -48,7 +48,8 @@ fn main() { device_name: "OnOff Light", }; - let mut mdns = matter::sys::LinuxMdns::new().unwrap(); + //let mut mdns = matter::mdns::bonjour::BonjourMdns::new().unwrap(); + let mut mdns = matter::mdns::libmdns::LibMdns::new().unwrap(); let matter = Matter::new_default(&dev_info, &mut mdns); diff --git a/matter/Cargo.toml b/matter/Cargo.toml index db10e775..85d113dc 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -16,7 +16,7 @@ path = "src/lib.rs" [features] default = ["std", "crypto_mbedtls", "nightly"] -std = ["alloc", "env_logger", "chrono", "rand", "qrcode", "smol"] +std = ["alloc", "env_logger", "chrono", "rand", "qrcode", "libmdns", "simple-mdns", "simple-dns", "smol"] alloc = [] nightly = [] crypto_openssl = ["openssl", "foreign-types", "hmac", "sha2"] @@ -43,8 +43,12 @@ colored = "2.0.0" # TODO: Requires STD env_logger = { version = "0.10.0", default-features = false, optional = true } chrono = { version = "0.4.23", optional = true, default-features = false, features = ["clock", "std"] } rand = { version = "0.8.5", optional = true } -smol = { version = "1.3.0", optional = true} qrcode = { version = "0.12", default-features = false, optional = true } # Print QR code +libmdns = { version = "0.7", optional = true } +simple-mdns = { version = "0.4", features = ["sync"], optional = true } +simple-dns = { version = "0.5", optional = true } +astro-dnssd = { version = "0.3", optional = true } +smol = { version = "1.3.0", optional = true} # crypto openssl = { git = "https://github.com/sfackler/rust-openssl", optional = true } @@ -66,13 +70,6 @@ x509-cert = { version = "0.2.0", default-features = false, features = ["pem", "s # to compute the check digit verhoeff = "1" -[target.'cfg(target_os = "macos")'.dependencies] -astro-dnssd = "0.3" - -# MDNS support -[target.'cfg(target_os = "linux")'.dependencies] -libmdns = { version = "0.7.4" } - [[example]] name = "onoff_light" path = "../examples/onoff_light/src/main.rs" diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index 71be231f..1f218669 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -25,6 +25,7 @@ pub trait Mdns { name: &str, service_type: &str, port: u16, + service_subtypes: &[&str], txt_kvs: &[(&str, &str)], ) -> Result<(), Error>; @@ -40,9 +41,10 @@ where name: &str, service_type: &str, port: u16, + service_subtypes: &[&str], txt_kvs: &[(&str, &str)], ) -> Result<(), Error> { - (**self).add(name, service_type, port, txt_kvs) + (**self).add(name, service_type, port, service_subtypes, txt_kvs) } fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { @@ -58,6 +60,7 @@ impl Mdns for DummyMdns { _name: &str, _service_type: &str, _port: u16, + _service_subtypes: &[&str], _txt_kvs: &[(&str, &str)], ) -> Result<(), Error> { Ok(()) @@ -112,11 +115,12 @@ impl<'a> MdnsMgr<'a> { #[allow(clippy::needless_pass_by_value)] pub fn publish_service(&mut self, name: &str, mode: ServiceMode) -> Result<(), Error> { match mode { - ServiceMode::Commissioned => self.mdns.add(name, "_matter._tcp", self.matter_port, &[]), + ServiceMode::Commissioned => { + self.mdns + .add(name, "_matter._tcp", self.matter_port, &[], &[]) + } ServiceMode::Commissionable(discriminator) => { let discriminator_str = Self::get_discriminator_str(discriminator); - - let serv_type = self.get_service_type(discriminator); let vp = self.get_vp(); let txt_kvs = [ @@ -129,7 +133,17 @@ impl<'a> MdnsMgr<'a> { ("PH", "33"), /* Pairing Hint */ ("PI", ""), /* Pairing Instruction */ ]; - self.mdns.add(name, &serv_type, self.matter_port, &txt_kvs) + + self.mdns.add( + name, + "_matter._udp", + self.matter_port, + &[ + &self.get_long_service_subtype(discriminator), + &self.get_short_service_type(discriminator), + ], + &txt_kvs, + ) } } } @@ -137,28 +151,32 @@ impl<'a> MdnsMgr<'a> { pub fn unpublish_service(&mut self, name: &str, mode: ServiceMode) -> Result<(), Error> { match mode { ServiceMode::Commissioned => self.mdns.remove(name, "_matter._tcp", self.matter_port), - ServiceMode::Commissionable(discriminator) => { - let serv_type = self.get_service_type(discriminator); - - self.mdns.remove(name, &serv_type, self.matter_port) + ServiceMode::Commissionable(_) => { + self.mdns.remove(name, "_matter._udp", self.matter_port) } } } - fn get_service_type(&self, discriminator: u16) -> heapless::String<32> { - let short = Self::compute_short_discriminator(discriminator); + fn get_long_service_subtype(&self, discriminator: u16) -> heapless::String<32> { let mut serv_type = heapless::String::new(); + write!(&mut serv_type, "_L{}", discriminator).unwrap(); - write!( - &mut serv_type, - "_matterc._udp,_S{},_L{}", - short, discriminator - ) - .unwrap(); + serv_type + } + + fn get_short_service_type(&self, discriminator: u16) -> heapless::String<32> { + let short = Self::compute_short_discriminator(discriminator); + + let mut serv_type = heapless::String::new(); + write!(&mut serv_type, "_S{}", short).unwrap(); serv_type } + fn get_discriminator_str(discriminator: u16) -> heapless::String<5> { + discriminator.into() + } + fn get_vp(&self) -> heapless::String<11> { let mut vp = heapless::String::new(); @@ -167,10 +185,6 @@ impl<'a> MdnsMgr<'a> { vp } - fn get_discriminator_str(discriminator: u16) -> heapless::String<5> { - discriminator.into() - } - fn compute_short_discriminator(discriminator: u16) -> u16 { const SHORT_DISCRIMINATOR_MASK: u16 = 0xF00; const SHORT_DISCRIMINATOR_SHIFT: u16 = 8; @@ -179,6 +193,353 @@ impl<'a> MdnsMgr<'a> { } } +#[cfg(all(feature = "std", feature = "bonjour"))] +pub mod bonjour { + use std::collections::HashMap; + + use super::Mdns; + use crate::error::Error; + use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService}; + use log::info; + + #[derive(Debug, Clone, Eq, PartialEq, Hash)] + pub struct ServiceId { + name: String, + service_type: String, + port: u16, + } + + pub struct BonjourMdns { + services: HashMap, + } + + impl BonjourMdns { + pub fn new() -> Result { + Ok(Self { + services: HashMap::new(), + }) + } + + pub fn add( + &mut self, + name: &str, + service_type: &str, + port: u16, + service_subtypes: &[&str], + txt_kvs: &[(&str, &str)], + ) -> Result<(), Error> { + info!( + "Registering mDNS service {}/{}/{}", + name, service_type, port + ); + + let _ = self.remove(name, service_type, port); + + let composite_service_type = if !service_subtypes.is_empty() { + format!("{}{}", service_type, service_subtypes.join(",")) + } else { + service_type + }; + + let mut builder = DNSServiceBuilder::new(composite_service_type, port).with_name(name); + + for kvs in txt_kvs { + info!("mDNS TXT key {} val {}", kvs.0, kvs.1); + builder = builder.with_key_value(kvs.0.to_string(), kvs.1.to_string()); + } + + let service = builder.register().map_err(|_| Error::MdnsError)?; + + self.services.insert( + ServiceId { + name: name.into(), + service_type: service_type.into(), + port, + }, + service, + ); + + Ok(()) + } + + pub fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { + let id = ServiceId { + name: name.into(), + service_type: service_type.into(), + port, + }; + + if self.services.remove(&id).is_some() { + info!( + "Deregistering mDNS service {}/{}/{}", + name, service_type, port + ); + } + + Ok(()) + } + } + + impl Mdns for BonjourMdns { + fn add( + &mut self, + name: &str, + service_type: &str, + port: u16, + service_subtypes: &[&str], + txt_kvs: &[(&str, &str)], + ) -> Result<(), Error> { + BonjourMdns::add(self, name, service_type, port, service_subtypes, txt_kvs) + } + + fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { + BonjourMdns::remove(self, name, service_type, port) + } + } +} + +#[cfg(feature = "std")] +pub mod libmdns { + use super::Mdns; + use crate::error::Error; + use libmdns::{Responder, Service}; + use log::info; + use std::collections::HashMap; + use std::vec::Vec; + + #[derive(Debug, Clone, Eq, PartialEq, Hash)] + pub struct ServiceId { + name: String, + service_type: String, + port: u16, + } + + pub struct LibMdns { + responder: Responder, + services: HashMap, + } + + impl LibMdns { + pub fn new() -> Result { + let responder = Responder::new()?; + + Ok(Self { + responder, + services: HashMap::new(), + }) + } + + pub fn add( + &mut self, + name: &str, + service_type: &str, + port: u16, + txt_kvs: &[(&str, &str)], + ) -> Result<(), Error> { + info!( + "Registering mDNS service {}/{}/{}", + name, service_type, port + ); + + let _ = self.remove(name, service_type, port); + + let mut properties = Vec::new(); + for kvs in txt_kvs { + info!("mDNS TXT key {} val {}", kvs.0, kvs.1); + properties.push(format!("{}={}", kvs.0, kvs.1)); + } + let properties: Vec<&str> = properties.iter().map(|entry| entry.as_str()).collect(); + + let service = self.responder.register( + service_type.to_owned(), + name.to_owned(), + port, + &properties, + ); + + self.services.insert( + ServiceId { + name: name.into(), + service_type: service_type.into(), + port, + }, + service, + ); + + Ok(()) + } + + pub fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { + let id = ServiceId { + name: name.into(), + service_type: service_type.into(), + port, + }; + + if self.services.remove(&id).is_some() { + info!( + "Deregistering mDNS service {}/{}/{}", + name, service_type, port + ); + } + + Ok(()) + } + } + + impl Mdns for LibMdns { + fn add( + &mut self, + name: &str, + service_type: &str, + port: u16, + _service_subtypes: &[&str], + txt_kvs: &[(&str, &str)], + ) -> Result<(), Error> { + LibMdns::add(self, name, service_type, port, txt_kvs) + } + + fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { + LibMdns::remove(self, name, service_type, port) + } + } +} + +// #[cfg(feature = "std")] +// pub mod simplemdns { +// use std::net::Ipv4Addr; + +// use crate::error::Error; +// use super::Mdns; +// use log::info; +// use simple_dns::{ +// rdata::{RData, A, SRV, TXT, PTR}, +// CharacterString, Name, ResourceRecord, CLASS, +// }; +// use simple_mdns::sync_discovery::SimpleMdnsResponder; + +// #[derive(Debug, Clone, Eq, PartialEq, Hash)] +// pub struct ServiceId { +// name: String, +// service_type: String, +// port: u16, +// } + +// pub struct SimpleMdns { +// responder: SimpleMdnsResponder, +// } + +// impl SimpleMdns { +// pub fn new() -> Result { +// Ok(Self { +// responder: Default::default(), +// }) +// } + +// pub fn add( +// &mut self, +// name: &str, +// service_type: &str, +// port: u16, +// txt_kvs: &[(&str, &str)], +// ) -> Result<(), Error> { +// info!( +// "Registering mDNS service {}/{}/{}", +// name, service_type, port +// ); + +// let _ = self.remove(name, service_type, port); + +// let mut txt = TXT::new(); +// for kvs in txt_kvs { +// info!("mDNS TXT key {} val {}", kvs.0, kvs.1); + +// let string = format!("{}={}", kvs.0, kvs.1); +// txt.add_char_string( +// CharacterString::new(string.as_bytes()) +// .unwrap() +// .into_owned(), +// ); +// } + +// let name = Name::new_unchecked(name).into_owned(); +// let service_type = Name::new_unchecked(service_type).into_owned(); + +// self.responder.add_resource(ResourceRecord::new( +// name.clone(), +// CLASS::IN, +// 10, +// RData::A(A { +// address: Ipv4Addr::new(192, 168, 10, 189).into(), +// }), +// )); + +// self.responder.add_resource(ResourceRecord::new( +// name.clone(), +// CLASS::IN, +// 10, +// RData::SRV(SRV { +// port: port, +// priority: 0, +// weight: 0, +// target: service_type.clone(), +// }), +// )); + +// self.responder.add_resource(ResourceRecord::new( +// srv_name.clone(), +// CLASS::IN, +// 10, +// RData::PTR(PTR(srv_name.clone()), +// ))); + +// self.responder.add_resource(ResourceRecord::new( +// srv_name, +// CLASS::IN, +// 10, +// RData::TXT(txt), +// )); + +// Ok(()) +// } + +// pub fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { +// // TODO +// // let id = ServiceId { +// // name: name.into(), +// // service_type: service_type.into(), +// // port, +// // }; + +// // if self.responder.remove_resource_record(resource).remove(&id).is_some() { +// // info!( +// // "Deregistering mDNS service {}/{}/{}", +// // name, service_type, port +// // ); +// // } + +// Ok(()) +// } +// } + +// impl Mdns for SimpleMdns { +// fn add( +// &mut self, +// name: &str, +// service_type: &str, +// port: u16, +// _service_subtypes: &[&str], +// txt_kvs: &[(&str, &str)], +// ) -> Result<(), Error> { +// SimpleMdns::add(self, name, service_type, port, txt_kvs) +// } + +// fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { +// SimpleMdns::remove(self, name, service_type, port) +// } +// } +// } + #[cfg(test)] mod tests { use super::*; diff --git a/matter/src/pairing/qr.rs b/matter/src/pairing/qr.rs index 33f7b8a4..e99d909c 100644 --- a/matter/src/pairing/qr.rs +++ b/matter/src/pairing/qr.rs @@ -315,10 +315,7 @@ fn estimate_struct_overhead(first_field_size: usize) -> usize { } pub(super) fn print_qr_code(qr_data: &str) { - #[cfg(not(feature = "std"))] - { - info!("\n QR CODE DATA: {}", qr_data); - } + info!("QR Code: {}", qr_data); #[cfg(feature = "std")] { diff --git a/matter/src/secure_channel/spake2p.rs b/matter/src/secure_channel/spake2p.rs index ba948f51..8a4b794c 100644 --- a/matter/src/secure_channel/spake2p.rs +++ b/matter/src/secure_channel/spake2p.rs @@ -17,7 +17,6 @@ use crate::{ crypto::{self, HmacSha256}, - sys, utils::rand::Rand, }; use byteorder::{ByteOrder, LittleEndian}; @@ -31,7 +30,7 @@ use crate::{ use super::{common::SCStatusCodes, crypto::CryptoSpake2}; -// This file handle Spake2+ specific instructions. In itself, this file is +// This file handles Spake2+ specific instructions. In itself, this file is // independent from the BigNum and EC operations that are typically required // Spake2+. We use the CryptoSpake2 trait object that allows us to abstract // out the specific implementations. @@ -39,6 +38,8 @@ use super::{common::SCStatusCodes, crypto::CryptoSpake2}; // In the case of the verifier, we don't actually release the Ke until we // validate that the cA is confirmed. +pub const SPAKE2_ITERATION_COUNT: u32 = 2000; + #[derive(PartialEq, Copy, Clone, Debug)] pub enum Spake2VerifierState { // Initialised - w0, L are set @@ -104,7 +105,7 @@ impl VerifierData { pub fn new_with_pw(pw: u32, rand: Rand) -> Self { let mut s = Self { salt: [0; MAX_SALT_SIZE_BYTES], - count: sys::SPAKE2_ITERATION_COUNT, + count: SPAKE2_ITERATION_COUNT, data: VerifierOption::Password(pw), }; rand(&mut s.salt); diff --git a/matter/src/sys/mod.rs b/matter/src/sys/mod.rs index 0ce65e71..a80b8759 100644 --- a/matter/src/sys/mod.rs +++ b/matter/src/sys/mod.rs @@ -15,18 +15,6 @@ * limitations under the License. */ -#[cfg(all(feature = "std", target_os = "macos"))] -mod sys_macos; -#[cfg(all(feature = "std", target_os = "macos"))] -pub use self::sys_macos::*; - -#[cfg(all(feature = "std", target_os = "linux"))] -mod sys_linux; -#[cfg(all(feature = "std", target_os = "linux"))] -pub use self::sys_linux::*; - -pub const SPAKE2_ITERATION_COUNT: u32 = 2000; - // The Packet Pool that is allocated from. POSIX systems can use // higher values unlike embedded systems pub const MAX_PACKET_POOL_SIZE: usize = 25; diff --git a/matter/src/sys/sys_linux.rs b/matter/src/sys/sys_linux.rs deleted file mode 100644 index 0d3f0dc2..00000000 --- a/matter/src/sys/sys_linux.rs +++ /dev/null @@ -1,116 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use crate::error::Error; -use crate::mdns::Mdns; -use libmdns::{Responder, Service}; -use log::info; -use std::collections::HashMap; -use std::vec::Vec; - -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub struct ServiceId { - name: String, - service_type: String, - port: u16, -} - -pub struct LinuxMdns { - responder: Responder, - services: HashMap, -} - -impl LinuxMdns { - pub fn new() -> Result { - let responder = Responder::new()?; - - Ok(Self { - responder, - services: HashMap::new(), - }) - } - - pub fn add( - &mut self, - name: &str, - service_type: &str, - port: u16, - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - info!( - "Registering mDNS service {}/{}/{}", - name, service_type, port - ); - - let _ = self.remove(name, service_type, port); - - let mut properties = Vec::new(); - for kvs in txt_kvs { - info!("mDNS TXT key {} val {}", kvs.0, kvs.1); - properties.push(format!("{}={}", kvs.0, kvs.1)); - } - let properties: Vec<&str> = properties.iter().map(|entry| entry.as_str()).collect(); - - let service = - self.responder - .register(service_type.to_owned(), name.to_owned(), port, &properties); - - self.services.insert( - ServiceId { - name: name.into(), - service_type: service_type.into(), - port, - }, - service, - ); - - Ok(()) - } - - pub fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { - let id = ServiceId { - name: name.into(), - service_type: service_type.into(), - port, - }; - - if self.services.remove(&id).is_some() { - info!( - "Deregistering mDNS service {}/{}/{}", - name, service_type, port - ); - } - - Ok(()) - } -} - -impl Mdns for LinuxMdns { - fn add( - &mut self, - name: &str, - service_type: &str, - port: u16, - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - LinuxMdns::add(self, name, service_type, port, txt_kvs) - } - - fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { - LinuxMdns::remove(self, name, service_type, port) - } -} diff --git a/matter/src/sys/sys_macos.rs b/matter/src/sys/sys_macos.rs deleted file mode 100644 index d8ffe3ac..00000000 --- a/matter/src/sys/sys_macos.rs +++ /dev/null @@ -1,109 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use std::collections::HashMap; - -use crate::{error::Error, mdns::Mdns}; -use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService}; -use log::info; - -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub struct ServiceId { - name: String, - service_type: String, - port: u16, -} - -pub struct MacOsMdns { - services: HashMap, -} - -impl MacOsMdns { - pub fn new() -> Result { - Ok(Self { - services: HashMap::new(), - }) - } - - pub fn add( - &mut self, - name: &str, - service_type: &str, - port: u16, - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - info!( - "Registering mDNS service {}/{}/{}", - name, service_type, port - ); - - let _ = self.remove(name, service_type, port); - - let mut builder = DNSServiceBuilder::new(service_type, port).with_name(name); - - for kvs in txt_kvs { - info!("mDNS TXT key {} val {}", kvs.0, kvs.1); - builder = builder.with_key_value(kvs.0.to_string(), kvs.1.to_string()); - } - - let service = builder.register().map_err(|_| Error::MdnsError)?; - - self.services.insert( - ServiceId { - name: name.into(), - service_type: service_type.into(), - port, - }, - service, - ); - - Ok(()) - } - - pub fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { - let id = ServiceId { - name: name.into(), - service_type: service_type.into(), - port, - }; - - if self.services.remove(&id).is_some() { - info!( - "Deregistering mDNS service {}/{}/{}", - name, service_type, port - ); - } - - Ok(()) - } -} - -impl Mdns for MacOsMdns { - fn add( - &mut self, - name: &str, - service_type: &str, - port: u16, - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - MacOsMdns::add(self, name, service_type, port, txt_kvs) - } - - fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { - MacOsMdns::remove(self, name, service_type, port) - } -} diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index 7a82fb33..b2dd1dc1 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -16,7 +16,7 @@ */ use crate::error::*; -use log::info; +use log::{info, warn}; use smol::net::{Ipv6Addr, UdpSocket}; use super::network::Address; @@ -35,25 +35,50 @@ pub const MATTER_PORT: u16 = 5540; impl UdpListener { pub async fn new() -> Result { - Ok(UdpListener { + let listener = UdpListener { socket: UdpSocket::bind((Ipv6Addr::UNSPECIFIED, MATTER_PORT)).await?, - }) + }; + + info!( + "Listening on {:?} port {}", + Ipv6Addr::UNSPECIFIED, + MATTER_PORT + ); + + Ok(listener) } pub async fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, Address), Error> { + info!("Waiting for incoming packets"); + let (size, addr) = self.socket.recv_from(in_buf).await.map_err(|e| { - info!("Error on the network: {:?}", e); + warn!("Error on the network: {:?}", e); Error::Network })?; + + info!("Got packet: {:?} from addr {:?}", in_buf, addr); + Ok((size, Address::Udp(addr))) } pub async fn send(&self, addr: Address, out_buf: &[u8]) -> Result { match addr { - Address::Udp(addr) => self.socket.send_to(out_buf, addr).await.map_err(|e| { - info!("Error on the network: {:?}", e); - Error::Network - }), + Address::Udp(addr) => { + let len = self.socket.send_to(out_buf, addr).await.map_err(|e| { + warn!("Error on the network: {:?}", e); + Error::Network + })?; + + info!( + "Send packet: {:?} ({}/{}) to addr {:?}", + out_buf, + out_buf.len(), + len, + addr + ); + + Ok(len) + } } } } From 36011c2e3cfaa9785b648dbb5e93c088033cbf9b Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 24 Apr 2023 19:28:22 +0000 Subject: [PATCH 16/72] Actually add the bonjour feature --- matter/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 85d113dc..4ddc25b7 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -19,6 +19,7 @@ default = ["std", "crypto_mbedtls", "nightly"] std = ["alloc", "env_logger", "chrono", "rand", "qrcode", "libmdns", "simple-mdns", "simple-dns", "smol"] alloc = [] nightly = [] +bonjour = ["astro-dnssd"] crypto_openssl = ["openssl", "foreign-types", "hmac", "sha2"] crypto_mbedtls = ["mbedtls", "alloc"] crypto_esp_mbedtls = ["esp-idf-sys"] From 8b3bb9527c8c7214173319434facf65d94fe8294 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 24 Apr 2023 21:41:16 +0000 Subject: [PATCH 17/72] Comm with chip-tool --- examples/onoff_light/src/main.rs | 15 +- matter/Cargo.toml | 6 +- matter/src/mdns.rs | 303 ++++++++++++++++++++++++------ matter/src/transport/exchange.rs | 29 ++- matter/src/transport/mgr.rs | 125 ++++++------ matter/src/transport/proto_ctx.rs | 3 +- matter/src/transport/udp.rs | 2 +- 7 files changed, 336 insertions(+), 147 deletions(-) diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index d4801887..123c0daf 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -28,8 +28,8 @@ use matter::data_model::sdm::dev_att::DevAttDataFetcher; use matter::interaction_model::core::InteractionModel; use matter::secure_channel::spake2p::VerifierData; use matter::transport::{ - mgr::RecvAction, mgr::TransportMgr, packet::Packet, packet::MAX_RX_BUF_SIZE, - packet::MAX_TX_BUF_SIZE, udp::UdpListener, + mgr::RecvAction, mgr::TransportMgr, packet::MAX_RX_BUF_SIZE, packet::MAX_TX_BUF_SIZE, + udp::UdpListener, }; mod dev_att; @@ -48,7 +48,7 @@ fn main() { device_name: "OnOff Light", }; - //let mut mdns = matter::mdns::bonjour::BonjourMdns::new().unwrap(); + //let mut mdns = matter::mdns::astro::AstroMdns::new().unwrap(); let mut mdns = matter::mdns::libmdns::LibMdns::new().unwrap(); let matter = Matter::new_default(&dev_info, &mut mdns); @@ -80,10 +80,7 @@ fn main() { let (len, addr) = udp.recv(&mut rx_buf).await.unwrap(); - let mut rx = Packet::new_rx(&mut rx_buf[..len]); - let mut tx = Packet::new_tx(&mut tx_buf); - - let mut completion = transport.recv(addr, &mut rx, &mut tx); + let mut completion = transport.recv(addr, &mut rx_buf[..len], &mut tx_buf); while let Some(action) = completion.next_action().unwrap() { match action { @@ -109,8 +106,8 @@ fn main() { InteractionModel(DataModel::new(matter.borrow(), &node, &mut handler)); if im.handle(&mut ctx).unwrap() { - if let Some(addr) = ctx.send().unwrap() { - udp.send(addr, ctx.tx.as_slice()).await.unwrap(); + if ctx.send().unwrap() { + udp.send(ctx.tx.peer, ctx.tx.as_slice()).await.unwrap(); } } } diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 4ddc25b7..cfc99646 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -15,11 +15,10 @@ name = "matter" path = "src/lib.rs" [features] -default = ["std", "crypto_mbedtls", "nightly"] +default = ["std", "crypto_mbedtls"] std = ["alloc", "env_logger", "chrono", "rand", "qrcode", "libmdns", "simple-mdns", "simple-dns", "smol"] alloc = [] nightly = [] -bonjour = ["astro-dnssd"] crypto_openssl = ["openssl", "foreign-types", "hmac", "sha2"] crypto_mbedtls = ["mbedtls", "alloc"] crypto_esp_mbedtls = ["esp-idf-sys"] @@ -48,7 +47,8 @@ qrcode = { version = "0.12", default-features = false, optional = true } # Print libmdns = { version = "0.7", optional = true } simple-mdns = { version = "0.4", features = ["sync"], optional = true } simple-dns = { version = "0.5", optional = true } -astro-dnssd = { version = "0.3", optional = true } +astro-dnssd = { version = "0.3", optional = true } # On Linux needs avahi-compat-libdns_sd, i.e. on Ubuntu/Debian do `sudo apt-get install libavahi-compat-libdnssd-dev` +zeroconf = { version = "0.10", optional = true } smol = { version = "1.3.0", optional = true} # crypto diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index 1f218669..1b296187 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -23,13 +23,15 @@ pub trait Mdns { fn add( &mut self, name: &str, - service_type: &str, + service: &str, + protocol: &str, port: u16, service_subtypes: &[&str], txt_kvs: &[(&str, &str)], ) -> Result<(), Error>; - fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error>; + fn remove(&mut self, name: &str, service: &str, protocol: &str, port: u16) + -> Result<(), Error>; } impl Mdns for &mut T @@ -39,16 +41,23 @@ where fn add( &mut self, name: &str, - service_type: &str, + service: &str, + protocol: &str, port: u16, service_subtypes: &[&str], txt_kvs: &[(&str, &str)], ) -> Result<(), Error> { - (**self).add(name, service_type, port, service_subtypes, txt_kvs) + (**self).add(name, service, protocol, port, service_subtypes, txt_kvs) } - fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { - (**self).remove(name, service_type, port) + fn remove( + &mut self, + name: &str, + service: &str, + protocol: &str, + port: u16, + ) -> Result<(), Error> { + (**self).remove(name, service, protocol, port) } } @@ -58,7 +67,8 @@ impl Mdns for DummyMdns { fn add( &mut self, _name: &str, - _service_type: &str, + _service: &str, + _protocol: &str, _port: u16, _service_subtypes: &[&str], _txt_kvs: &[(&str, &str)], @@ -66,7 +76,13 @@ impl Mdns for DummyMdns { Ok(()) } - fn remove(&mut self, _name: &str, _service_type: &str, _port: u16) -> Result<(), Error> { + fn remove( + &mut self, + _name: &str, + _service: &str, + _protocol: &str, + _port: u16, + ) -> Result<(), Error> { Ok(()) } } @@ -117,7 +133,7 @@ impl<'a> MdnsMgr<'a> { match mode { ServiceMode::Commissioned => { self.mdns - .add(name, "_matter._tcp", self.matter_port, &[], &[]) + .add(name, "_matter", "_tcp", self.matter_port, &[], &[]) } ServiceMode::Commissionable(discriminator) => { let discriminator_str = Self::get_discriminator_str(discriminator); @@ -136,7 +152,8 @@ impl<'a> MdnsMgr<'a> { self.mdns.add( name, - "_matter._udp", + "_matterc", + "_udp", self.matter_port, &[ &self.get_long_service_subtype(discriminator), @@ -150,9 +167,11 @@ impl<'a> MdnsMgr<'a> { pub fn unpublish_service(&mut self, name: &str, mode: ServiceMode) -> Result<(), Error> { match mode { - ServiceMode::Commissioned => self.mdns.remove(name, "_matter._tcp", self.matter_port), + ServiceMode::Commissioned => { + self.mdns.remove(name, "_matter", "_tcp", self.matter_port) + } ServiceMode::Commissionable(_) => { - self.mdns.remove(name, "_matter._udp", self.matter_port) + self.mdns.remove(name, "_matterc", "_udp", self.matter_port) } } } @@ -193,8 +212,8 @@ impl<'a> MdnsMgr<'a> { } } -#[cfg(all(feature = "std", feature = "bonjour"))] -pub mod bonjour { +#[cfg(all(feature = "std", feature = "astro-dnssd"))] +pub mod astro { use std::collections::HashMap; use super::Mdns; @@ -205,15 +224,16 @@ pub mod bonjour { #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct ServiceId { name: String, - service_type: String, + service: String, + protocol: String, port: u16, } - pub struct BonjourMdns { - services: HashMap, + pub struct AstroMdns { + services: HashMap, } - impl BonjourMdns { + impl AstroMdns { pub fn new() -> Result { Ok(Self { services: HashMap::new(), @@ -223,56 +243,65 @@ pub mod bonjour { pub fn add( &mut self, name: &str, - service_type: &str, + service: &str, + protocol: &str, port: u16, service_subtypes: &[&str], txt_kvs: &[(&str, &str)], ) -> Result<(), Error> { info!( - "Registering mDNS service {}/{}/{}", - name, service_type, port + "Registering mDNS service {}/{}.{} [{:?}]/{}", + name, service, protocol, service_subtypes, port ); - let _ = self.remove(name, service_type, port); + let _ = self.remove(name, service, protocol, port); let composite_service_type = if !service_subtypes.is_empty() { - format!("{}{}", service_type, service_subtypes.join(",")) + format!("{}.{},{}", service, protocol, service_subtypes.join(",")) } else { - service_type + format!("{}.{}", service, protocol) }; - let mut builder = DNSServiceBuilder::new(composite_service_type, port).with_name(name); + let mut builder = DNSServiceBuilder::new(&composite_service_type, port).with_name(name); for kvs in txt_kvs { info!("mDNS TXT key {} val {}", kvs.0, kvs.1); builder = builder.with_key_value(kvs.0.to_string(), kvs.1.to_string()); } - let service = builder.register().map_err(|_| Error::MdnsError)?; + let svc = builder.register().map_err(|_| Error::MdnsError)?; self.services.insert( ServiceId { name: name.into(), - service_type: service_type.into(), + service: service.into(), + protocol: protocol.into(), port, }, - service, + svc, ); Ok(()) } - pub fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { + pub fn remove( + &mut self, + name: &str, + service: &str, + protocol: &str, + port: u16, + ) -> Result<(), Error> { let id = ServiceId { name: name.into(), - service_type: service_type.into(), + service: service.into(), + protocol: protocol.into(), port, }; if self.services.remove(&id).is_some() { info!( - "Deregistering mDNS service {}/{}/{}", - name, service_type, port + "Deregistering mDNS service {}/{}.{}/{}", + name, service, protocol, port ); } @@ -280,24 +309,172 @@ pub mod bonjour { } } - impl Mdns for BonjourMdns { + impl Mdns for AstroMdns { fn add( &mut self, name: &str, - service_type: &str, + service: &str, + protocol: &str, port: u16, service_subtypes: &[&str], txt_kvs: &[(&str, &str)], ) -> Result<(), Error> { - BonjourMdns::add(self, name, service_type, port, service_subtypes, txt_kvs) + AstroMdns::add( + self, + name, + service, + protocol, + port, + service_subtypes, + txt_kvs, + ) } - fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { - BonjourMdns::remove(self, name, service_type, port) + fn remove( + &mut self, + name: &str, + service: &str, + protocol: &str, + port: u16, + ) -> Result<(), Error> { + AstroMdns::remove(self, name, service, protocol, port) } } } +// TODO: Maybe future +// #[cfg(all(feature = "std", feature = "zeroconf"))] +// pub mod zeroconf { +// use std::collections::HashMap; + +// use super::Mdns; +// use crate::error::Error; +// use log::info; +// use zeroconf::prelude::*; +// use zeroconf::{MdnsService, ServiceType, TxtRecord}; + +// #[derive(Debug, Clone, Eq, PartialEq, Hash)] +// pub struct ServiceId { +// name: String, +// service: String, +// protocol: String, +// port: u16, +// } + +// pub struct ZeroconfMdns { +// services: HashMap, +// } + +// impl ZeroconfMdns { +// pub fn new() -> Result { +// Ok(Self { +// services: HashMap::new(), +// }) +// } + +// pub fn add( +// &mut self, +// name: &str, +// service: &str, +// protocol: &str, +// port: u16, +// service_subtypes: &[&str], +// txt_kvs: &[(&str, &str)], +// ) -> Result<(), Error> { +// info!( +// "Registering mDNS service {}/{}.{} [{:?}]/{}", +// name, service, protocol, service_subtypes, port +// ); + +// let _ = self.remove(name, service, protocol, port); + +// let mut svc = MdnsService::new( +// ServiceType::with_sub_types(service, protocol, service_subtypes.into()).unwrap(), +// port, +// ); + +// let mut txt = TxtRecord::new(); + +// for kvs in txt_kvs { +// info!("mDNS TXT key {} val {}", kvs.0, kvs.1); +// txt.insert(kvs.0, kvs.1); +// } + +// svc.set_txt_record(txt); + +// //let event_loop = svc.register().map_err(|_| Error::MdnsError)?; + +// self.services.insert( +// ServiceId { +// name: name.into(), +// service: service.into(), +// protocol: protocol.into(), +// port, +// }, +// svc, +// ); + +// Ok(()) +// } + +// pub fn remove( +// &mut self, +// name: &str, +// service: &str, +// protocol: &str, +// port: u16, +// ) -> Result<(), Error> { +// let id = ServiceId { +// name: name.into(), +// service: service.into(), +// protocol: protocol.into(), +// port, +// }; + +// if self.services.remove(&id).is_some() { +// info!( +// "Deregistering mDNS service {}.{}/{}/{}", +// name, service, protocol, port +// ); +// } + +// Ok(()) +// } +// } + +// impl Mdns for ZeroconfMdns { +// fn add( +// &mut self, +// name: &str, +// service: &str, +// protocol: &str, +// port: u16, +// service_subtypes: &[&str], +// txt_kvs: &[(&str, &str)], +// ) -> Result<(), Error> { +// ZeroconfMdns::add( +// self, +// name, +// service, +// protocol, +// port, +// service_subtypes, +// txt_kvs, +// ) +// } + +// fn remove( +// &mut self, +// name: &str, +// service: &str, +// protocol: &str, +// port: u16, +// ) -> Result<(), Error> { +// ZeroconfMdns::remove(self, name, service, protocol, port) +// } +// } +// } + #[cfg(feature = "std")] pub mod libmdns { use super::Mdns; @@ -310,7 +487,8 @@ pub mod libmdns { #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct ServiceId { name: String, - service_type: String, + service: String, + protocol: String, port: u16, } @@ -332,16 +510,17 @@ pub mod libmdns { pub fn add( &mut self, name: &str, - service_type: &str, + service: &str, + protocol: &str, port: u16, txt_kvs: &[(&str, &str)], ) -> Result<(), Error> { info!( - "Registering mDNS service {}/{}/{}", - name, service_type, port + "Registering mDNS service {}/{}.{}/{}", + name, service, protocol, port ); - let _ = self.remove(name, service_type, port); + let _ = self.remove(name, service, protocol, port); let mut properties = Vec::new(); for kvs in txt_kvs { @@ -350,8 +529,8 @@ pub mod libmdns { } let properties: Vec<&str> = properties.iter().map(|entry| entry.as_str()).collect(); - let service = self.responder.register( - service_type.to_owned(), + let svc = self.responder.register( + format!("{}.{}", service, protocol), name.to_owned(), port, &properties, @@ -360,26 +539,34 @@ pub mod libmdns { self.services.insert( ServiceId { name: name.into(), - service_type: service_type.into(), + service: service.into(), + protocol: protocol.into(), port, }, - service, + svc, ); Ok(()) } - pub fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { + pub fn remove( + &mut self, + name: &str, + service: &str, + protocol: &str, + port: u16, + ) -> Result<(), Error> { let id = ServiceId { name: name.into(), - service_type: service_type.into(), + service: service.into(), + protocol: protocol.into(), port, }; if self.services.remove(&id).is_some() { info!( - "Deregistering mDNS service {}/{}/{}", - name, service_type, port + "Deregistering mDNS service {}/{}.{}/{}", + name, service, protocol, port ); } @@ -391,20 +578,28 @@ pub mod libmdns { fn add( &mut self, name: &str, - service_type: &str, + service: &str, + protocol: &str, port: u16, _service_subtypes: &[&str], txt_kvs: &[(&str, &str)], ) -> Result<(), Error> { - LibMdns::add(self, name, service_type, port, txt_kvs) + LibMdns::add(self, name, service, protocol, port, txt_kvs) } - fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { - LibMdns::remove(self, name, service_type, port) + fn remove( + &mut self, + name: &str, + service: &str, + protocol: &str, + port: u16, + ) -> Result<(), Error> { + LibMdns::remove(self, name, service, protocol, port) } } } +// TODO: Maybe future // #[cfg(feature = "std")] // pub mod simplemdns { // use std::net::Ipv4Addr; diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index 8e3df105..333eab3f 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -31,10 +31,7 @@ use crate::utils::rand::Rand; use heapless::LinearMap; use super::session::CloneData; -use super::{ - mrp::ReliableMessage, network::Address, packet::Packet, session::SessionHandle, - session::SessionMgr, -}; +use super::{mrp::ReliableMessage, packet::Packet, session::SessionHandle, session::SessionMgr}; pub struct ExchangeCtx<'a> { pub exch: &'a mut Exchange, @@ -43,7 +40,7 @@ pub struct ExchangeCtx<'a> { } impl<'a> ExchangeCtx<'a> { - pub fn send(&mut self, tx: &mut Packet) -> Result, Error> { + pub fn send(&mut self, tx: &mut Packet) -> Result { self.exch.send(tx, &mut self.sess) } } @@ -201,10 +198,10 @@ impl Exchange { &mut self, tx: &mut Packet, session: &mut SessionHandle, - ) -> Result, Error> { + ) -> Result { if self.state == State::Terminate { info!("Skipping tx for terminated exchange {}", self.id); - return Ok(None); + return Ok(false); } trace!("payload: {:x?}", tx.as_mut_slice()); @@ -224,7 +221,7 @@ impl Exchange { self.mrp.pre_send(tx)?; session.send(tx)?; - Ok(Some(session.get_peer_addr())) + Ok(true) } } @@ -359,13 +356,11 @@ impl ExchangeMgr { } } - pub fn send(&mut self, exch_id: u16, tx: &mut Packet) -> Result { + pub fn send(&mut self, exch_id: u16, tx: &mut Packet) -> Result { let exchange = ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id).ok_or(Error::NoExchange)?; let mut session = self.sess_mgr.get_session_handle(exchange.sess_idx); - exchange.send(tx, &mut session)?; - - Ok(session.get_peer_addr()) + exchange.send(tx, &mut session) } pub fn purge(&mut self) { @@ -388,7 +383,7 @@ impl ExchangeMgr { .map(|(exch_id, _)| *exch_id) } - pub fn evict_session(&mut self, tx: &mut Packet) -> Result, Error> { + pub fn evict_session(&mut self, tx: &mut Packet) -> Result { if let Some(index) = self.sess_mgr.get_session_for_eviction() { info!("Sessions full, vacating session with index: {}", index); // If we enter here, we have an LRU session that needs to be reclaimed @@ -431,13 +426,11 @@ impl ExchangeMgr { self.exchanges.remove(&exch_id); } - let addr = session.get_peer_addr(); - self.sess_mgr.remove(index); - Ok(Some(addr)) + Ok(true) } else { - Ok(None) + Ok(false) } } @@ -571,7 +564,7 @@ mod tests { let mut buf = [0; MAX_TX_BUF_SIZE]; let tx = &mut Packet::new_tx(&mut buf); let evicted = mgr.evict_session(tx).unwrap(); - assert!(evicted.is_some()); + assert!(evicted); let session = mgr .add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id)) diff --git a/matter/src/transport/mgr.rs b/matter/src/transport/mgr.rs index c4bc2e7b..e33e30e9 100644 --- a/matter/src/transport/mgr.rs +++ b/matter/src/transport/mgr.rs @@ -28,11 +28,10 @@ use crate::secure_channel::pake::PaseMgr; use crate::secure_channel::common::PROTO_ID_SECURE_CHANNEL; use crate::secure_channel::core::SecureChannel; use crate::transport::mrp::ReliableMessage; -use crate::transport::{exchange, packet::Packet}; +use crate::transport::{exchange, network::Address, packet::Packet}; use crate::utils::epoch::{Epoch, UtcCalendar}; use crate::utils::rand::Rand; -use super::network::Address; use super::proto_ctx::ProtoCtx; use super::session::CloneData; @@ -52,9 +51,8 @@ pub enum RecvAction<'r, 'p> { pub struct RecvCompletion<'r, 'a, 'p> { mgr: &'r mut TransportMgr<'a>, - addr: Address, // TODO: Not used yet - rx: &'r mut Packet<'p>, - tx: &'r mut Packet<'p>, + rx: Packet<'p>, + tx: Packet<'p>, state: RecvState, } @@ -73,91 +71,94 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { fn maybe_next_action(&mut self) -> Result>>, Error> { self.mgr.exch_mgr.purge(); - match core::mem::replace(&mut self.state, RecvState::New) { + let (state, next) = match core::mem::replace(&mut self.state, RecvState::New) { RecvState::New => { - self.mgr.exch_mgr.get_sess_mgr().decode(self.rx)?; - self.state = RecvState::OpenExchange; - Ok(None) + self.mgr.exch_mgr.get_sess_mgr().decode(&mut self.rx)?; + (RecvState::OpenExchange, None) } - RecvState::OpenExchange => match self.mgr.exch_mgr.recv(self.rx) { + RecvState::OpenExchange => match self.mgr.exch_mgr.recv(&mut self.rx) { Ok(Some(exch_ctx)) => { if self.rx.get_proto_id() == PROTO_ID_SECURE_CHANNEL { - let mut proto_ctx = ProtoCtx::new(exch_ctx, self.rx, self.tx); + let mut proto_ctx = ProtoCtx::new(exch_ctx, &self.rx, &mut self.tx); let (reply, clone_data) = self.mgr.secure_channel.handle(&mut proto_ctx)?; - if let Some(clone_data) = clone_data { - self.state = RecvState::AddSession(clone_data); + let state = if let Some(clone_data) = clone_data { + RecvState::AddSession(clone_data) } else { - self.state = RecvState::Ack; - } - - let addr = if reply { proto_ctx.send()? } else { None }; - - if let Some(addr) = addr { - Ok(Some(Some(RecvAction::Send(addr, self.tx.as_slice())))) + RecvState::Ack + }; + + if reply { + if proto_ctx.send()? { + ( + state, + Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))), + ) + } else { + (state, None) + } } else { - Ok(None) + (state, None) } } else { - let proto_ctx = ProtoCtx::new(exch_ctx, self.rx, self.tx); - self.state = RecvState::Ack; + let proto_ctx = ProtoCtx::new(exch_ctx, &self.rx, &mut self.tx); - Ok(Some(Some(RecvAction::Interact(proto_ctx)))) + (RecvState::Ack, Some(Some(RecvAction::Interact(proto_ctx)))) } } - Ok(None) => { - self.state = RecvState::Ack; - Ok(None) - } - Err(Error::NoSpace) => { - self.state = RecvState::EvictSession; - Ok(None) - } - Err(err) => Err(err), + Ok(None) => (RecvState::Ack, None), + Err(Error::Duplicate) => (RecvState::Ack, Some(None)), + Err(Error::NoSpace) => (RecvState::EvictSession, None), + Err(err) => Err(err)?, }, RecvState::AddSession(clone_data) => match self.mgr.exch_mgr.add_session(&clone_data) { - Ok(_) => { - self.state = RecvState::Ack; - Ok(None) - } - Err(Error::NoSpace) => { - self.state = RecvState::EvictSession2(clone_data); - Ok(None) - } - Err(err) => Err(err), + Ok(_) => (RecvState::Ack, None), + Err(Error::NoSpace) => (RecvState::EvictSession2(clone_data), None), + Err(err) => Err(err)?, }, RecvState::EvictSession => { - let addr = self.mgr.exch_mgr.evict_session(self.tx)?; - self.state = RecvState::OpenExchange; - if let Some(addr) = addr { - Ok(Some(Some(RecvAction::Send(addr, self.tx.as_slice())))) + if self.mgr.exch_mgr.evict_session(&mut self.tx)? { + ( + RecvState::OpenExchange, + Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))), + ) } else { - Ok(None) + (RecvState::EvictSession, None) } } RecvState::EvictSession2(clone_data) => { - let addr = self.mgr.exch_mgr.evict_session(self.tx)?; - self.state = RecvState::AddSession(clone_data); - if let Some(addr) = addr { - Ok(Some(Some(RecvAction::Send(addr, self.tx.as_slice())))) + if self.mgr.exch_mgr.evict_session(&mut self.tx)? { + ( + RecvState::AddSession(clone_data), + Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))), + ) } else { - Ok(None) + (RecvState::EvictSession2(clone_data), None) } } RecvState::Ack => { if let Some(exch_id) = self.mgr.exch_mgr.pending_ack() { info!("Sending MRP Standalone ACK for exch {}", exch_id); - ReliableMessage::prepare_ack(exch_id, self.tx); + ReliableMessage::prepare_ack(exch_id, &mut self.tx); - let addr = self.mgr.exch_mgr.send(exch_id, self.tx)?; - Ok(Some(Some(RecvAction::Send(addr, self.tx.as_slice())))) + if self.mgr.exch_mgr.send(exch_id, &mut self.tx)? { + ( + RecvState::Ack, + Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))), + ) + } else { + (RecvState::Ack, None) + } } else { - Ok(Some(None)) + (RecvState::Ack, Some(None)) } } - } + }; + + self.state = state; + Ok(next) } } @@ -232,12 +233,16 @@ impl<'a> TransportMgr<'a> { pub fn recv<'r, 'p>( &'r mut self, addr: Address, - rx: &'r mut Packet<'p>, - tx: &'r mut Packet<'p>, + rx_buf: &'p mut [u8], + tx_buf: &'p mut [u8], ) -> RecvCompletion<'r, 'a, 'p> { + let mut rx = Packet::new_rx(rx_buf); + let tx = Packet::new_tx(tx_buf); + + rx.peer = addr; + RecvCompletion { mgr: self, - addr, rx, tx, state: RecvState::New, diff --git a/matter/src/transport/proto_ctx.rs b/matter/src/transport/proto_ctx.rs index c7b95db3..b7374eca 100644 --- a/matter/src/transport/proto_ctx.rs +++ b/matter/src/transport/proto_ctx.rs @@ -18,7 +18,6 @@ use crate::error::Error; use super::exchange::ExchangeCtx; -use super::network::Address; use super::packet::Packet; /// This is the context in which a receive packet is being processed @@ -36,7 +35,7 @@ impl<'a, 'b> ProtoCtx<'a, 'b> { Self { exch_ctx, rx, tx } } - pub fn send(&mut self) -> Result, Error> { + pub fn send(&mut self) -> Result { self.exch_ctx.exch.send(self.tx, &mut self.exch_ctx.sess) } } diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index b2dd1dc1..b3c4c484 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -56,7 +56,7 @@ impl UdpListener { Error::Network })?; - info!("Got packet: {:?} from addr {:?}", in_buf, addr); + info!("Got packet: {:?} from addr {:?}", &in_buf[..size], addr); Ok((size, Address::Udp(addr))) } From 7ef7e93eb49a3df26d4c21b019060e3d0703700d Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Tue, 25 Apr 2023 07:22:58 +0000 Subject: [PATCH 18/72] Heap-allocated packets not necessary; no_std and no-alloc build supported end-to-end --- Cargo.toml | 2 +- boxslab/Cargo.toml | 9 -- boxslab/src/lib.rs | 237 ------------------------------- examples/onoff_light/src/main.rs | 2 +- matter/Cargo.toml | 2 +- matter/src/core.rs | 8 +- matter/src/error.rs | 6 +- matter/src/lib.rs | 1 - matter/src/sys/mod.rs | 20 --- matter/src/transport/mod.rs | 1 + matter/src/transport/network.rs | 3 + matter/src/transport/packet.rs | 56 +------- 12 files changed, 14 insertions(+), 333 deletions(-) delete mode 100644 boxslab/Cargo.toml delete mode 100644 boxslab/src/lib.rs delete mode 100644 matter/src/sys/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 7b2660c7..268671d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,4 @@ [workspace] -members = ["matter", "matter_macro_derive", "boxslab", "tools/tlv_tool"] +members = ["matter", "matter_macro_derive", "tools/tlv_tool"] exclude = ["examples/*"] diff --git a/boxslab/Cargo.toml b/boxslab/Cargo.toml deleted file mode 100644 index fa4f4302..00000000 --- a/boxslab/Cargo.toml +++ /dev/null @@ -1,9 +0,0 @@ -[package] -name = "boxslab" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -bitmaps={version="3.2.0", features=[]} diff --git a/boxslab/src/lib.rs b/boxslab/src/lib.rs deleted file mode 100644 index f25cbd40..00000000 --- a/boxslab/src/lib.rs +++ /dev/null @@ -1,237 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use std::{ - mem::MaybeUninit, - ops::{Deref, DerefMut}, - sync::Mutex, -}; - -// TODO: why is max bitmap size 64 a correct max size? Could we match -// boxslabs instead or store used/not used inside the box slabs themselves? -const MAX_BITMAP_SIZE: usize = 64; -pub struct Bitmap { - inner: bitmaps::Bitmap, - max_size: usize, -} - -impl Bitmap { - pub fn new(max_size: usize) -> Self { - assert!(max_size <= MAX_BITMAP_SIZE); - Bitmap { - inner: bitmaps::Bitmap::new(), - max_size, - } - } - - pub fn set(&mut self, index: usize) -> bool { - assert!(index < self.max_size); - self.inner.set(index, true) - } - - pub fn reset(&mut self, index: usize) -> bool { - assert!(index < self.max_size); - self.inner.set(index, false) - } - - pub fn first_false_index(&self) -> Option { - match self.inner.first_false_index() { - Some(idx) if idx < self.max_size => Some(idx), - _ => None, - } - } - - pub fn is_empty(&self) -> bool { - self.inner.is_empty() - } - - pub fn is_full(&self) -> bool { - self.first_false_index().is_none() - } -} - -#[macro_export] -macro_rules! box_slab { - ($name:ident,$t:ty,$v:expr) => { - use std::mem::MaybeUninit; - use std::sync::Once; - use $crate::{BoxSlab, Slab, SlabPool}; - - pub struct $name; - - impl SlabPool for $name { - type SlabType = $t; - fn get_slab() -> &'static Slab { - const MAYBE_INIT: MaybeUninit<$t> = MaybeUninit::uninit(); - static mut SLAB_POOL: [MaybeUninit<$t>; $v] = [MAYBE_INIT; $v]; - static mut SLAB_SPACE: Option> = None; - static mut INIT: Once = Once::new(); - unsafe { - INIT.call_once(|| { - SLAB_SPACE = Some(Slab::<$name>::init(&mut SLAB_POOL, $v)); - }); - SLAB_SPACE.as_ref().unwrap() - } - } - } - }; -} - -pub trait SlabPool { - type SlabType: 'static; - fn get_slab() -> &'static Slab - where - Self: Sized; -} - -pub struct Inner { - pool: &'static mut [MaybeUninit], - map: Bitmap, -} - -// TODO: Instead of a mutex, we should replace this with a CAS loop -pub struct Slab(Mutex>); - -impl Slab { - pub fn init(pool: &'static mut [MaybeUninit], size: usize) -> Self { - Self(Mutex::new(Inner { - pool, - map: Bitmap::new(size), - })) - } - - pub fn try_new(new_object: T::SlabType) -> Option> { - let slab = T::get_slab(); - let mut inner = slab.0.lock().unwrap(); - if let Some(index) = inner.map.first_false_index() { - inner.map.set(index); - inner.pool[index].write(new_object); - let cell_ptr = unsafe { &mut *inner.pool[index].as_mut_ptr() }; - Some(BoxSlab { - data: cell_ptr, - index, - }) - } else { - None - } - } - - pub fn free(&self, index: usize) { - let mut inner = self.0.lock().unwrap(); - inner.map.reset(index); - let old_value = std::mem::replace(&mut inner.pool[index], MaybeUninit::uninit()); - let _old_value = unsafe { old_value.assume_init() }; - // This will drop the old_value - } -} - -pub struct BoxSlab { - // Because the data is a reference within the MaybeUninit, we don't have a mechanism - // to go out to the MaybeUninit from this reference. Hence this index - index: usize, - // TODO: We should figure out a way to get rid of the index too - data: &'static mut T::SlabType, -} - -impl Drop for BoxSlab { - fn drop(&mut self) { - T::get_slab().free(self.index); - } -} - -impl Deref for BoxSlab { - type Target = T::SlabType; - fn deref(&self) -> &Self::Target { - self.data - } -} - -impl DerefMut for BoxSlab { - fn deref_mut(&mut self) -> &mut Self::Target { - self.data - } -} - -#[cfg(test)] -mod tests { - use std::{ops::Deref, sync::Arc}; - - pub struct Test { - val: Arc, - } - - box_slab!(TestSlab, Test, 3); - - #[test] - fn simple_alloc_free() { - { - let a = Slab::::try_new(Test { val: Arc::new(10) }).unwrap(); - assert_eq!(*a.val.deref(), 10); - let inner = TestSlab::get_slab().0.lock().unwrap(); - assert!(!inner.map.is_empty()); - } - // Validates that the 'Drop' got executed - let inner = TestSlab::get_slab().0.lock().unwrap(); - assert!(inner.map.is_empty()); - println!("Box Size {}", std::mem::size_of::>()); - println!("BoxSlab Size {}", std::mem::size_of::>()); - } - - #[test] - fn alloc_full_block() { - { - let a = Slab::::try_new(Test { val: Arc::new(10) }).unwrap(); - let b = Slab::::try_new(Test { val: Arc::new(11) }).unwrap(); - let c = Slab::::try_new(Test { val: Arc::new(12) }).unwrap(); - // Test that at overflow, we return None - assert!(Slab::::try_new(Test { val: Arc::new(13) }).is_none(),); - assert_eq!(*b.val.deref(), 11); - - { - let inner = TestSlab::get_slab().0.lock().unwrap(); - // Test that the bitmap is marked as full - assert!(inner.map.is_full()); - } - - // Purposefully drop, to test that new allocation is possible - std::mem::drop(b); - let d = Slab::::try_new(Test { val: Arc::new(21) }).unwrap(); - assert_eq!(*d.val.deref(), 21); - - // Ensure older allocations are still valid - assert_eq!(*a.val.deref(), 10); - assert_eq!(*c.val.deref(), 12); - } - - // Validates that the 'Drop' got executed - test that the bitmap is empty - let inner = TestSlab::get_slab().0.lock().unwrap(); - assert!(inner.map.is_empty()); - } - - #[test] - fn test_drop_logic() { - let root = Arc::new(10); - { - let _a = Slab::::try_new(Test { val: root.clone() }).unwrap(); - let _b = Slab::::try_new(Test { val: root.clone() }).unwrap(); - let _c = Slab::::try_new(Test { val: root.clone() }).unwrap(); - assert_eq!(Arc::strong_count(&root), 4); - } - // Test that Drop was correctly called on all the members of the pool - assert_eq!(Arc::strong_count(&root), 1); - } -} diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 123c0daf..1748b599 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -51,7 +51,7 @@ fn main() { //let mut mdns = matter::mdns::astro::AstroMdns::new().unwrap(); let mut mdns = matter::mdns::libmdns::LibMdns::new().unwrap(); - let matter = Matter::new_default(&dev_info, &mut mdns); + let matter = Matter::new_default(&dev_info, &mut mdns, matter::transport::udp::MATTER_PORT); let dev_att = dev_att::HardCodedDevAtt::new(); diff --git a/matter/Cargo.toml b/matter/Cargo.toml index cfc99646..9f5503b3 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -25,7 +25,6 @@ crypto_esp_mbedtls = ["esp-idf-sys"] crypto_rustcrypto = ["sha2", "hmac", "pbkdf2", "hkdf", "aes", "ccm", "p256", "elliptic-curve", "crypto-bigint", "x509-cert"] [dependencies] -boxslab = { path = "../boxslab" } matter_macro_derive = { path = "../matter_macro_derive" } bitflags = "1.3" byteorder = "1.4.3" @@ -35,6 +34,7 @@ num-derive = "0.3.3" num-traits = "0.2.15" strum = { version = "0.24", features = ["derive"], default-features = false, no-default-feature = true } log = { version = "0.4.17", features = ["max_level_debug", "release_max_level_debug"] } +no-std-net = "0.6" subtle = "2.4.1" safemem = "0.3.3" colored = "2.0.0" # TODO: Requires STD diff --git a/matter/src/core.rs b/matter/src/core.rs index 41c53071..0939a4a5 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -25,7 +25,6 @@ use crate::{ mdns::{Mdns, MdnsMgr}, pairing::{print_pairing_code_and_qr, DiscoveryCapabilities}, secure_channel::{pake::PaseMgr, spake2p::VerifierData}, - transport::udp::MATTER_PORT, utils::{ epoch::{Epoch, UtcCalendar}, rand::Rand, @@ -55,11 +54,11 @@ pub struct Matter<'a> { impl<'a> Matter<'a> { #[cfg(feature = "std")] - pub fn new_default(dev_det: &'a BasicInfoConfig, mdns: &'a mut dyn Mdns) -> Self { + pub fn new_default(dev_det: &'a BasicInfoConfig, mdns: &'a mut dyn Mdns, port: u16) -> Self { use crate::utils::epoch::{sys_epoch, sys_utc_calendar}; use crate::utils::rand::sys_rand; - Self::new(dev_det, mdns, sys_epoch, sys_rand, sys_utc_calendar) + Self::new(dev_det, mdns, sys_epoch, sys_rand, sys_utc_calendar, port) } /// Creates a new Matter object @@ -74,6 +73,7 @@ impl<'a> Matter<'a> { epoch: Epoch, rand: Rand, utc_calendar: UtcCalendar, + port: u16, ) -> Self { Self { fabric_mgr: RefCell::new(FabricMgr::new()), @@ -84,7 +84,7 @@ impl<'a> Matter<'a> { dev_det.vid, dev_det.pid, dev_det.device_name, - MATTER_PORT, + port, mdns, )), epoch, diff --git a/matter/src/error.rs b/matter/src/error.rs index e644a7aa..d2053d16 100644 --- a/matter/src/error.rs +++ b/matter/src/error.rs @@ -17,8 +17,6 @@ use core::{array::TryFromSliceError, fmt, str::Utf8Error}; -use log::error; - #[derive(Debug, PartialEq, Clone, Copy)] pub enum Error { AttributeNotFound, @@ -122,7 +120,7 @@ impl From> for Error { #[cfg(feature = "crypto_openssl")] impl From for Error { fn from(e: openssl::error::ErrorStack) -> Self { - error!("Error in TLS: {}", e); + ::log::error!("Error in TLS: {}", e); Self::TLSStack } } @@ -130,7 +128,7 @@ impl From for Error { #[cfg(feature = "crypto_mbedtls")] impl From for Error { fn from(e: mbedtls::Error) -> Self { - error!("Error in TLS: {}", e); + ::log::error!("Error in TLS: {}", e); Self::TLSStack } } diff --git a/matter/src/lib.rs b/matter/src/lib.rs index 0d99cdb7..1d7e5d4a 100644 --- a/matter/src/lib.rs +++ b/matter/src/lib.rs @@ -85,7 +85,6 @@ pub mod mdns; pub mod pairing; pub mod persist; pub mod secure_channel; -pub mod sys; pub mod tlv; pub mod transport; pub mod utils; diff --git a/matter/src/sys/mod.rs b/matter/src/sys/mod.rs deleted file mode 100644 index a80b8759..00000000 --- a/matter/src/sys/mod.rs +++ /dev/null @@ -1,20 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// The Packet Pool that is allocated from. POSIX systems can use -// higher values unlike embedded systems -pub const MAX_PACKET_POOL_SIZE: usize = 25; diff --git a/matter/src/transport/mod.rs b/matter/src/transport/mod.rs index 1a81c75c..0b6453ee 100644 --- a/matter/src/transport/mod.rs +++ b/matter/src/transport/mod.rs @@ -25,4 +25,5 @@ pub mod plain_hdr; pub mod proto_ctx; pub mod proto_hdr; pub mod session; +#[cfg(feature = "std")] pub mod udp; diff --git a/matter/src/transport/network.rs b/matter/src/transport/network.rs index 91645de6..6cda9bcd 100644 --- a/matter/src/transport/network.rs +++ b/matter/src/transport/network.rs @@ -16,6 +16,9 @@ */ use core::fmt::{Debug, Display}; +#[cfg(not(feature = "std"))] +use no_std_net::{IpAddr, Ipv4Addr, SocketAddr}; +#[cfg(feature = "std")] use std::net::{IpAddr, Ipv4Addr, SocketAddr}; #[derive(PartialEq, Copy, Clone)] diff --git a/matter/src/transport/packet.rs b/matter/src/transport/packet.rs index b2ca7aad..a86bf697 100644 --- a/matter/src/transport/packet.rs +++ b/matter/src/transport/packet.rs @@ -15,14 +15,10 @@ * limitations under the License. */ -use log::{error, trace}; -use std::sync::Mutex; - -use boxslab::box_slab; +use log::error; use crate::{ error::Error, - sys::MAX_PACKET_POOL_SIZE, utils::{parsebuf::ParseBuf, writebuf::WriteBuf}, }; @@ -34,54 +30,6 @@ use super::{ pub const MAX_RX_BUF_SIZE: usize = 1583; pub const MAX_TX_BUF_SIZE: usize = 1280 - 40/*IPV6 header size*/ - 8/*UDP header size*/; -type Buffer = [u8; MAX_RX_BUF_SIZE]; - -// TODO: I am not very happy with this construction, need to find another way to do this -pub struct BufferPool { - buffers: [Option; MAX_PACKET_POOL_SIZE], -} - -impl BufferPool { - const INIT: Option = None; - fn get() -> &'static Mutex { - static mut BUFFER_HOLDER: Option> = None; - static ONCE: Once = Once::new(); - unsafe { - ONCE.call_once(|| { - BUFFER_HOLDER = Some(Mutex::new(BufferPool { - buffers: [BufferPool::INIT; MAX_PACKET_POOL_SIZE], - })); - }); - BUFFER_HOLDER.as_ref().unwrap() - } - } - - pub fn alloc() -> Option<(usize, &'static mut Buffer)> { - trace!("Buffer Alloc called\n"); - - let mut pool = BufferPool::get().lock().unwrap(); - for i in 0..MAX_PACKET_POOL_SIZE { - if pool.buffers[i].is_none() { - pool.buffers[i] = Some([0; MAX_RX_BUF_SIZE]); - // Sigh! to by-pass the borrow-checker telling us we are stealing a mutable reference - // from under the lock - // In this case the lock only protects against the setting of Some/None, - // the objects then are independently accessed in a unique way - let buffer = unsafe { &mut *(pool.buffers[i].as_mut().unwrap() as *mut Buffer) }; - return Some((i, buffer)); - } - } - None - } - - pub fn free(index: usize) { - trace!("Buffer Free called\n"); - let mut pool = BufferPool::get().lock().unwrap(); - if pool.buffers[index].is_some() { - pool.buffers[index] = None; - } - } -} #[derive(PartialEq)] enum RxState { @@ -229,5 +177,3 @@ impl<'a> Packet<'a> { } } } - -box_slab!(PacketPool, Packet<'static>, MAX_PACKET_POOL_SIZE); From 9a23a2af2d263deb149c4d7a623859dc227cb3bf Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Tue, 25 Apr 2023 19:47:48 +0000 Subject: [PATCH 19/72] Bugfix: arm failsafe was reporting wrong status --- .../data_model/sdm/general_commissioning.rs | 49 ++++++++++++------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/matter/src/data_model/sdm/general_commissioning.rs b/matter/src/data_model/sdm/general_commissioning.rs index 0c007d1b..3882f30b 100644 --- a/matter/src/data_model/sdm/general_commissioning.rs +++ b/matter/src/data_model/sdm/general_commissioning.rs @@ -22,7 +22,6 @@ use crate::data_model::objects::*; use crate::data_model::sdm::failsafe::FailSafe; use crate::interaction_model::core::Transaction; use crate::tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; -use crate::transport::session::Session; use crate::utils::rand::Rand; use crate::{attribute_enum, cmd_enter}; use crate::{command_enum, error::*}; @@ -173,18 +172,18 @@ impl GenCommCluster { pub fn invoke( &mut self, - session: &mut Session, + transaction: &mut Transaction, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { match cmd.cmd_id.try_into()? { - Commands::ArmFailsafe => self.handle_command_armfailsafe(session, data, encoder)?, + Commands::ArmFailsafe => self.handle_command_armfailsafe(transaction, data, encoder)?, Commands::SetRegulatoryConfig => { - self.handle_command_setregulatoryconfig(data, encoder)? + self.handle_command_setregulatoryconfig(transaction, data, encoder)? } Commands::CommissioningComplete => { - self.handle_command_commissioningcomplete(session, encoder)?; + self.handle_command_commissioningcomplete(transaction, encoder)?; } } @@ -195,7 +194,7 @@ impl GenCommCluster { fn handle_command_armfailsafe( &mut self, - session: &Session, + transaction: &mut Transaction, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { @@ -203,23 +202,33 @@ impl GenCommCluster { let p = FailSafeParams::from_tlv(data)?; - self.failsafe + let status = if self + .failsafe .borrow_mut() - .arm(p.expiry_len, session.get_session_mode()) - .map_err(|e| e.remap(|_| true, Error::Busy))?; + .arm(p.expiry_len, transaction.session().get_session_mode()) + .is_err() + { + CommissioningError::ErrBusyWithOtherAdmin as u8 + } else { + CommissioningError::Ok as u8 + }; let cmd_data = CommonResponse { - error_code: CommissioningError::ErrBusyWithOtherAdmin as u8, + error_code: status, debug_txt: UtfStr::new(b""), }; encoder .with_command(RespCommands::ArmFailsafeResp as _)? - .set(cmd_data) + .set(cmd_data)?; + + transaction.complete(); + Ok(()) } fn handle_command_setregulatoryconfig( &mut self, + transaction: &mut Transaction, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { @@ -238,19 +247,22 @@ impl GenCommCluster { encoder .with_command(RespCommands::SetRegulatoryConfigResp as _)? - .set(cmd_data) + .set(cmd_data)?; + + transaction.complete(); + Ok(()) } fn handle_command_commissioningcomplete( &mut self, - session: &Session, + transaction: &mut Transaction, encoder: CmdDataEncoder, ) -> Result<(), Error> { cmd_enter!("Commissioning Complete"); let mut status: u8 = CommissioningError::Ok as u8; // Has to be a Case Session - if session.get_local_fabric_idx().is_none() { + if transaction.session().get_local_fabric_idx().is_none() { status = CommissioningError::ErrInvalidAuth as u8; } @@ -259,7 +271,7 @@ impl GenCommCluster { if self .failsafe .borrow_mut() - .disarm(session.get_session_mode()) + .disarm(transaction.session().get_session_mode()) .is_err() { status = CommissioningError::ErrInvalidAuth as u8; @@ -272,7 +284,10 @@ impl GenCommCluster { encoder .with_command(RespCommands::CommissioningCompleteResp as _)? - .set(cmd_data) + .set(cmd_data)?; + + transaction.complete(); + Ok(()) } } @@ -288,7 +303,7 @@ impl Handler for GenCommCluster { data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { - GenCommCluster::invoke(self, transaction.session_mut(), cmd, data, encoder) + GenCommCluster::invoke(self, transaction, cmd, data, encoder) } } From b2805570ea8f10493ae7ce52821efaae177b60c4 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Tue, 25 Apr 2023 19:48:04 +0000 Subject: [PATCH 20/72] Restore transaction completion code --- matter/src/data_model/cluster_on_off.rs | 7 +- matter/src/data_model/sdm/noc.rs | 106 ++++++++++++++---------- matter/src/interaction_model/core.rs | 10 ++- matter/tests/common/im_engine.rs | 9 +- 4 files changed, 82 insertions(+), 50 deletions(-) diff --git a/matter/src/data_model/cluster_on_off.rs b/matter/src/data_model/cluster_on_off.rs index 9b173673..1a26522a 100644 --- a/matter/src/data_model/cluster_on_off.rs +++ b/matter/src/data_model/cluster_on_off.rs @@ -112,6 +112,7 @@ impl OnOffCluster { pub fn invoke( &mut self, + transaction: &mut Transaction, cmd: &CmdDetails, _data: &TLVElement, _encoder: CmdDataEncoder, @@ -131,6 +132,8 @@ impl OnOffCluster { } } + transaction.complete(); + self.data_ver.changed(); Ok(()) @@ -148,12 +151,12 @@ impl Handler for OnOffCluster { fn invoke( &mut self, - _transaction: &mut Transaction, + transaction: &mut Transaction, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { - OnOffCluster::invoke(self, cmd, data, encoder) + OnOffCluster::invoke(self, transaction, cmd, data, encoder) } } diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index b2dcee21..3bf6d858 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -27,7 +27,7 @@ use crate::fabric::{Fabric, FabricMgr, MAX_SUPPORTED_FABRICS}; use crate::interaction_model::core::Transaction; use crate::mdns::MdnsMgr; use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; -use crate::transport::session::{Session, SessionMode}; +use crate::transport::session::SessionMode; use crate::utils::epoch::Epoch; use crate::utils::rand::Rand; use crate::utils::writebuf::WriteBuf; @@ -294,21 +294,17 @@ impl<'a> NocCluster<'a> { encoder: CmdDataEncoder, ) -> Result<(), Error> { match cmd.cmd_id.try_into()? { - Commands::AddNOC => { - self.handle_command_addnoc(transaction.session_mut(), data, encoder)? - } - Commands::CSRReq => { - self.handle_command_csrrequest(transaction.session_mut(), data, encoder)? - } + Commands::AddNOC => self.handle_command_addnoc(transaction, data, encoder)?, + Commands::CSRReq => self.handle_command_csrrequest(transaction, data, encoder)?, Commands::AddTrustedRootCert => { - self.handle_command_addtrustedrootcert(transaction.session_mut(), data)? + self.handle_command_addtrustedrootcert(transaction, data)? } - Commands::AttReq => { - self.handle_command_attrequest(transaction.session(), data, encoder)? + Commands::AttReq => self.handle_command_attrequest(transaction, data, encoder)?, + Commands::CertChainReq => { + self.handle_command_certchainrequest(transaction, data, encoder)? } - Commands::CertChainReq => self.handle_command_certchainrequest(data, encoder)?, Commands::UpdateFabricLabel => { - self.handle_command_updatefablabel(transaction.session(), data, encoder)?; + self.handle_command_updatefablabel(transaction, data, encoder)?; } Commands::RemoveFabric => self.handle_command_rmfabric(transaction, data, encoder)?, } @@ -326,10 +322,13 @@ impl<'a> NocCluster<'a> { fn _handle_command_addnoc( &mut self, - session: &mut Session, + transaction: &mut Transaction, data: &TLVElement, ) -> Result { - let noc_data = session.take_noc_data().ok_or(NocStatus::MissingCsr)?; + let noc_data = transaction + .session_mut() + .take_noc_data() + .ok_or(NocStatus::MissingCsr)?; if !self .failsafe @@ -389,6 +388,7 @@ impl<'a> NocCluster<'a> { self.failsafe.borrow_mut().record_add_noc(fab_idx)?; + transaction.complete(); Ok(fab_idx) } @@ -411,32 +411,36 @@ impl<'a> NocCluster<'a> { fn handle_command_updatefablabel( &mut self, - session: &Session, + transaction: &mut Transaction, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { cmd_enter!("Update Fabric Label"); let req = UpdateFabricLabelReq::from_tlv(data).map_err(Error::map_invalid_data_type)?; - let (result, fab_idx) = if let SessionMode::Case(c) = session.get_session_mode() { - if self - .fabric_mgr - .borrow_mut() - .set_label( - c.fab_idx, - req.label.as_str().map_err(Error::map_invalid_data_type)?, - ) - .is_err() - { - (NocStatus::LabelConflict, c.fab_idx) + let (result, fab_idx) = + if let SessionMode::Case(c) = transaction.session().get_session_mode() { + if self + .fabric_mgr + .borrow_mut() + .set_label( + c.fab_idx, + req.label.as_str().map_err(Error::map_invalid_data_type)?, + ) + .is_err() + { + (NocStatus::LabelConflict, c.fab_idx) + } else { + (NocStatus::Ok, c.fab_idx) + } } else { - (NocStatus::Ok, c.fab_idx) - } - } else { - // Update Fabric Label not allowed - (NocStatus::InvalidFabricIndex, 0) - }; + // Update Fabric Label not allowed + (NocStatus::InvalidFabricIndex, 0) + }; + + Self::create_nocresponse(encoder, result, fab_idx, "")?; - Self::create_nocresponse(encoder, result, fab_idx, "") + transaction.complete(); + Ok(()) } fn handle_command_rmfabric( @@ -463,13 +467,13 @@ impl<'a> NocCluster<'a> { fn handle_command_addnoc( &mut self, - session: &mut Session, + transaction: &mut Transaction, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { cmd_enter!("AddNOC"); - let (status, fab_idx) = match self._handle_command_addnoc(session, data) { + let (status, fab_idx) = match self._handle_command_addnoc(transaction, data) { Ok(fab_idx) => (NocStatus::Ok, fab_idx), Err(NocError::Status(status)) => (status, 0), Err(NocError::Error(error)) => Err(error)?, @@ -480,7 +484,7 @@ impl<'a> NocCluster<'a> { fn handle_command_attrequest( &mut self, - session: &Session, + transaction: &mut Transaction, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { @@ -490,7 +494,7 @@ impl<'a> NocCluster<'a> { info!("Received Attestation Nonce:{:?}", req.str); let mut attest_challenge = [0u8; crypto::SYMM_KEY_LEN_BYTES]; - attest_challenge.copy_from_slice(session.get_att_challenge()); + attest_challenge.copy_from_slice(transaction.session().get_att_challenge()); let mut writer = encoder.with_command(RespCommands::AttReqResp as _)?; @@ -512,11 +516,15 @@ impl<'a> NocCluster<'a> { )?; writer.end_container()?; - writer.complete() + writer.complete()?; + + transaction.complete(); + Ok(()) } fn handle_command_certchainrequest( &mut self, + transaction: &mut Transaction, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { @@ -535,12 +543,15 @@ impl<'a> NocCluster<'a> { encoder .with_command(RespCommands::CertChainResp as _)? - .set(cmd_data) + .set(cmd_data)?; + + transaction.complete(); + Ok(()) } fn handle_command_csrrequest( &mut self, - session: &mut Session, + transaction: &mut Transaction, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { @@ -555,7 +566,7 @@ impl<'a> NocCluster<'a> { let noc_keypair = KeyPair::new()?; let mut attest_challenge = [0u8; crypto::SYMM_KEY_LEN_BYTES]; - attest_challenge.copy_from_slice(session.get_att_challenge()); + attest_challenge.copy_from_slice(transaction.session().get_att_challenge()); let mut writer = encoder.with_command(RespCommands::CSRResp as _)?; @@ -576,14 +587,15 @@ impl<'a> NocCluster<'a> { let noc_data = NocData::new(noc_keypair); // Store this in the session data instead of cluster data, so it gets cleared // if the session goes away for some reason - session.set_noc_data(noc_data); + transaction.session_mut().set_noc_data(noc_data); + transaction.complete(); Ok(()) } fn handle_command_addtrustedrootcert( &mut self, - session: &mut Session, + transaction: &mut Transaction, data: &TLVElement, ) -> Result<(), Error> { cmd_enter!("AddTrustedRootCert"); @@ -592,10 +604,13 @@ impl<'a> NocCluster<'a> { } // This may happen on CASE or PASE. For PASE, the existence of NOC Data is necessary - match session.get_session_mode() { + match transaction.session().get_session_mode() { SessionMode::Case(_) => error!("CASE: AddTrustedRootCert handling pending"), // For a CASE Session, we just return success for now, SessionMode::Pase => { - let noc_data = session.get_noc_data::().ok_or(Error::NoSession)?; + let noc_data = transaction + .session_mut() + .get_noc_data::() + .ok_or(Error::NoSession)?; let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; info!("Received Trusted Cert:{:x?}", req.str); @@ -607,6 +622,7 @@ impl<'a> NocCluster<'a> { _ => (), } + transaction.complete(); Ok(()) } } diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index fd9e130f..371f8e86 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -395,7 +395,6 @@ impl<'a> WriteReq<'a> { Interaction::create_status_response(tx, IMStatusCode::Timeout)?; transaction.complete(); - transaction.ctx.exch.close(); Ok(false) } else { @@ -436,7 +435,6 @@ impl<'a> InvReq<'a> { Interaction::create_status_response(tx, IMStatusCode::Timeout)?; transaction.complete(); - transaction.ctx.exch.close(); Ok(false) } else { @@ -738,6 +736,10 @@ where true }; + if transaction.is_complete() { + transaction.exch_mut().close(); + } + Ok(reply) } } @@ -757,6 +759,10 @@ where true }; + if transaction.is_complete() { + transaction.exch_mut().close(); + } + Ok(reply) } } diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index 4cdbf042..86674ecf 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -109,7 +109,14 @@ impl<'a> ImInput<'a> { pub type DmHandler<'a> = handler_chain_type!(OnOffCluster, EchoCluster, DescriptorCluster, EchoCluster | RootEndpointHandler<'a>); pub fn matter(mdns: &mut dyn Mdns) -> Matter<'_> { - Matter::new(&BASIC_INFO, mdns, sys_epoch, dummy_rand, sys_utc_calendar) + Matter::new( + &BASIC_INFO, + mdns, + sys_epoch, + dummy_rand, + sys_utc_calendar, + 5540, + ) } /// An Interaction Model Engine to facilitate easy testing From f9536be1e3457fecc43d50f7f9398553b0616c8f Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Tue, 25 Apr 2023 20:23:17 +0000 Subject: [PATCH 21/72] Bugfix: two separate failsafe instances were used --- matter/src/data_model/root_endpoint.rs | 4 ++-- .../src/data_model/sdm/general_commissioning.rs | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/matter/src/data_model/root_endpoint.rs b/matter/src/data_model/root_endpoint.rs index 341b3de1..7ad87fb6 100644 --- a/matter/src/data_model/root_endpoint.rs +++ b/matter/src/data_model/root_endpoint.rs @@ -30,7 +30,7 @@ use super::{ pub type RootEndpointHandler<'a> = handler_chain_type!( DescriptorCluster, BasicInfoCluster<'a>, - GenCommCluster, + GenCommCluster<'a>, NwCommCluster, AdminCommCluster<'a>, NocCluster<'a>, @@ -107,7 +107,7 @@ pub fn wrap<'a>( .chain( endpoint_id, general_commissioning::ID, - GenCommCluster::new(rand), + GenCommCluster::new(failsafe, rand), ) .chain( endpoint_id, diff --git a/matter/src/data_model/sdm/general_commissioning.rs b/matter/src/data_model/sdm/general_commissioning.rs index 3882f30b..d4d43297 100644 --- a/matter/src/data_model/sdm/general_commissioning.rs +++ b/matter/src/data_model/sdm/general_commissioning.rs @@ -121,17 +121,17 @@ struct FailSafeParams { bread_crumb: u8, } -pub struct GenCommCluster { +pub struct GenCommCluster<'a> { data_ver: Dataver, expiry_len: u16, - failsafe: RefCell, + failsafe: &'a RefCell, } -impl GenCommCluster { - pub fn new(rand: Rand) -> Self { +impl<'a> GenCommCluster<'a> { + pub fn new(failsafe: &'a RefCell, rand: Rand) -> Self { Self { data_ver: Dataver::new(rand), - failsafe: RefCell::new(FailSafe::new()), + failsafe, // TODO: Arch-Specific expiry_len: 120, } @@ -291,7 +291,7 @@ impl GenCommCluster { } } -impl Handler for GenCommCluster { +impl<'a> Handler for GenCommCluster<'a> { fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { GenCommCluster::read(self, attr, encoder) } @@ -307,9 +307,9 @@ impl Handler for GenCommCluster { } } -impl NonBlockingHandler for GenCommCluster {} +impl<'a> NonBlockingHandler for GenCommCluster<'a> {} -impl ChangeNotifier<()> for GenCommCluster { +impl<'a> ChangeNotifier<()> for GenCommCluster<'a> { fn consume_change(&mut self) -> Option<()> { self.data_ver.consume_change(()) } From f804c21c0b42951c84cec4b46dfd1d98fe3b1efa Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Wed, 26 Apr 2023 05:20:02 +0000 Subject: [PATCH 22/72] Bugfix: fabric adding wrongly started at index 0 --- matter/src/crypto/crypto_mbedtls.rs | 6 ++++++ matter/src/fabric.rs | 30 ++++++++++++++--------------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/matter/src/crypto/crypto_mbedtls.rs b/matter/src/crypto/crypto_mbedtls.rs index 2890fd19..c87e669a 100644 --- a/matter/src/crypto/crypto_mbedtls.rs +++ b/matter/src/crypto/crypto_mbedtls.rs @@ -199,6 +199,12 @@ impl KeyPair { } } +impl core::fmt::Debug for KeyPair { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("KeyPair").finish() + } +} + fn convert_r_s_to_asn1_sign(signature: &[u8], mbedtls_sign: &mut [u8]) -> Result { let r = &signature[0..32]; let s = &signature[32..64]; diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index 6c9d389c..688c56cf 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -53,7 +53,7 @@ const ST_PBKEY: &str = "pubkey"; const ST_PRKEY: &str = "privkey"; #[allow(dead_code)] -#[derive(ToTLV)] +#[derive(Debug, ToTLV)] #[tlvargs(lifetime = "'a", start = 1)] pub struct FabricDescriptor<'a> { root_public_key: OctetStr<'a>, @@ -66,6 +66,7 @@ pub struct FabricDescriptor<'a> { pub fab_idx: Option, } +#[derive(Debug)] pub struct Fabric { node_id: u64, fabric_id: u64, @@ -532,22 +533,21 @@ impl FabricMgr { } pub fn add(&mut self, f: Fabric, mdns_mgr: &mut MdnsMgr) -> Result { - let index = self - .fabrics - .iter() - .skip(1) - .position(|f| f.is_none()) - .ok_or(Error::NoSpace)?; - - self.fabrics[index] = Some(f); - mdns_mgr.publish_service( - &self.fabrics[index].as_ref().unwrap().mdns_service_name, - ServiceMode::Commissioned, - )?; + for i in 1..MAX_SUPPORTED_FABRICS { + if self.fabrics[i].is_none() { + self.fabrics[i] = Some(f); + mdns_mgr.publish_service( + &self.fabrics[i].as_ref().unwrap().mdns_service_name, + ServiceMode::Commissioned, + )?; + + self.changed = true; - self.changed = true; + return Ok(i as u8); + } + } - Ok(index as u8) + Err(Error::NoSpace) } pub fn remove(&mut self, fab_idx: u8, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> { From 996446613812561f77f4b78b25379536878fb0ff Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Wed, 26 Apr 2023 17:28:19 +0000 Subject: [PATCH 23/72] MRP standalone ack messages should not be acknowledged --- matter/src/secure_channel/core.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index e69dca59..fd13206e 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -63,7 +63,7 @@ impl<'a> SecureChannel<'a> { info!("Received Data:"); tlv::print_tlv_list(ctx.rx.as_slice()); let (reply, clone_data) = match proto_opcode { - OpCode::MRPStandAloneAck => Ok((true, None)), + OpCode::MRPStandAloneAck => Ok((false, None)), OpCode::PBKDFParamRequest => self .pase .borrow_mut() From 2fc4e6ddcf26d1b4bc03d54b267a9102296618a1 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Wed, 26 Apr 2023 19:17:19 +0000 Subject: [PATCH 24/72] Root cert buffer too short --- matter/src/cert/mod.rs | 2 +- matter/src/data_model/objects/encoder.rs | 10 ++++++++-- matter/src/interaction_model/core.rs | 9 ++++++++- matter/src/transport/exchange.rs | 6 ++++-- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/matter/src/cert/mod.rs b/matter/src/cert/mod.rs index 7af18677..621b28d4 100644 --- a/matter/src/cert/mod.rs +++ b/matter/src/cert/mod.rs @@ -29,7 +29,7 @@ use num_derive::FromPrimitive; pub use self::asn1_writer::ASN1Writer; use self::printer::CertPrinter; -pub const MAX_CERT_TLV_LEN: usize = 300; // TODO +pub const MAX_CERT_TLV_LEN: usize = 1024; // TODO // As per https://datatracker.ietf.org/doc/html/rfc5280 diff --git a/matter/src/data_model/objects/encoder.rs b/matter/src/data_model/objects/encoder.rs index b4066e6b..d068ce7b 100644 --- a/matter/src/data_model/objects/encoder.rs +++ b/matter/src/data_model/objects/encoder.rs @@ -368,10 +368,16 @@ impl<'a, 'b, 'c> CmdDataEncoder<'a, 'b, 'c> { match handler.invoke(transaction, &cmd, &data, encoder) { Ok(()) => cmd.success(&tracker), - Err(error) => cmd.status(error.into()), + Err(error) => { + error!("Error invoking command: {}", error); + cmd.status(error.into()) + } } } - Err(status) => Some(status), + Err(status) => { + error!("Error invoking command: {:?}", status); + Some(status) + } }; if let Some(status) = status { diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index 371f8e86..1e9827bc 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -470,6 +470,8 @@ impl<'a> InvReq<'a> { } pub fn complete(self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { + let suppress = self.suppress_response.unwrap_or_default(); + let mut tw = TLVWriter::new(tx.get_writebuf()?); if self.inv_requests.is_some() { @@ -478,7 +480,12 @@ impl<'a> InvReq<'a> { tw.end_container()?; - Ok(true) + Ok(if suppress { + error!("Supress response is set, is this the expected handling?"); + false + } else { + true + }) } } diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index 333eab3f..a25baba2 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -25,6 +25,7 @@ use crate::error::Error; use crate::interaction_model::core::{ResumeReadReq, ResumeSubscribeReq}; use crate::secure_channel; use crate::secure_channel::case::CaseSession; +use crate::tlv::print_tlv_list; use crate::utils::epoch::Epoch; use crate::utils::rand::Rand; @@ -204,13 +205,14 @@ impl Exchange { return Ok(false); } - trace!("payload: {:x?}", tx.as_mut_slice()); + trace!("payload: {:x?}", tx.as_slice()); info!( - "{} with proto id: {} opcode: {}", + "{} with proto id: {} opcode: {}, tlv:\n", "Sending".blue(), tx.get_proto_id(), tx.get_proto_opcode(), ); + print_tlv_list(tx.as_slice()); tx.proto.exch_id = self.id; if self.role == Role::Initiator { From 09a523fc508455f9502f09b8ba58e1627258ea51 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Fri, 28 Apr 2023 05:26:36 +0000 Subject: [PATCH 25/72] TX packets are reused; need way to reset them --- matter/src/interaction_model/core.rs | 2 ++ matter/src/transport/packet.rs | 13 +++++++++++++ 2 files changed, 15 insertions(+) diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index 1e9827bc..abf76a48 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -271,6 +271,8 @@ impl<'a> Interaction<'a> { transaction: &mut Transaction, ) -> Result, Error> { if let Some(interaction) = Self::new(rx, transaction)? { + tx.reset(); + let initiated = match &interaction { Interaction::Read(req) => req.initiate(tx, transaction)?, Interaction::Write(req) => req.initiate(tx, transaction)?, diff --git a/matter/src/transport/packet.rs b/matter/src/transport/packet.rs index a86bf697..d56485f1 100644 --- a/matter/src/transport/packet.rs +++ b/matter/src/transport/packet.rs @@ -78,6 +78,19 @@ impl<'a> Packet<'a> { } } + pub fn reset(&mut self) -> () { + if let Direction::Tx(wb) = &mut self.data { + wb.reset(); + wb.reserve(Packet::HDR_RESERVE).unwrap(); + + self.plain = Default::default(); + self.proto = Default::default(); + self.peer = Address::default(); + + self.proto.set_reliable(); + } + } + pub fn as_slice(&self) -> &[u8] { match &self.data { Direction::Rx(pb, _) => pb.as_slice(), From 2a57ecbd87a3d5d77f14c54b030dd78ea97d2dca Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Fri, 28 Apr 2023 06:13:54 +0000 Subject: [PATCH 26/72] Bugfix: only report devtype for the queried endpoint --- matter/src/data_model/system_model/descriptor.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/matter/src/data_model/system_model/descriptor.rs b/matter/src/data_model/system_model/descriptor.rs index 2df17f57..b434586d 100644 --- a/matter/src/data_model/system_model/descriptor.rs +++ b/matter/src/data_model/system_model/descriptor.rs @@ -71,7 +71,12 @@ impl DescriptorCluster { } else { match attr.attr_id.try_into()? { Attributes::DeviceTypeList => { - self.encode_devtype_list(attr.node, AttrDataWriter::TAG, &mut writer)?; + self.encode_devtype_list( + attr.node, + attr.endpoint_id, + AttrDataWriter::TAG, + &mut writer, + )?; writer.complete() } Attributes::ServerList => { @@ -111,13 +116,16 @@ impl DescriptorCluster { fn encode_devtype_list( &self, node: &Node, + endpoint_id: u16, tag: TagType, tw: &mut TLVWriter, ) -> Result<(), Error> { tw.start_array(tag)?; for endpoint in node.endpoints { - let dev_type = endpoint.device_type; - dev_type.to_tlv(tw, TagType::Anonymous)?; + if endpoint.id == endpoint_id { + let dev_type = endpoint.device_type; + dev_type.to_tlv(tw, TagType::Anonymous)?; + } } tw.end_container() From 635be2c35a0f285acfd94574557cfb8ce3d8ab43 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Fri, 28 Apr 2023 06:38:52 +0000 Subject: [PATCH 27/72] Error log on arm failure --- matter/src/data_model/sdm/failsafe.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/matter/src/data_model/sdm/failsafe.rs b/matter/src/data_model/sdm/failsafe.rs index 54b22e6a..5008c9fa 100644 --- a/matter/src/data_model/sdm/failsafe.rs +++ b/matter/src/data_model/sdm/failsafe.rs @@ -61,6 +61,7 @@ impl FailSafe { } State::Armed(c) => { if c.session_mode != session_mode { + error!("Received Fail-Safe Arm with different session modes; current {:?}, incoming {:?}", c.session_mode, session_mode); return Err(Error::Invalid); } // re-arm From 076ba06e079b4c404689bfd8e1dcea4ede81601a Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Fri, 28 Apr 2023 08:41:41 +0000 Subject: [PATCH 28/72] Bugfix: missing descriptor cluster --- examples/onoff_light/src/main.rs | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 1748b599..7ffc70e5 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -25,6 +25,7 @@ use matter::data_model::device_types::DEV_TYPE_ON_OFF_LIGHT; use matter::data_model::objects::*; use matter::data_model::root_endpoint; use matter::data_model::sdm::dev_att::DevAttDataFetcher; +use matter::data_model::system_model::descriptor; use matter::interaction_model::core::InteractionModel; use matter::secure_channel::spake2p::VerifierData; use matter::transport::{ @@ -95,7 +96,7 @@ fn main() { Endpoint { id: 1, device_type: DEV_TYPE_ON_OFF_LIGHT, - clusters: &[cluster_on_off::CLUSTER], + clusters: &[descriptor::CLUSTER, cluster_on_off::CLUSTER], }, ], }; @@ -118,9 +119,15 @@ fn main() { } fn handler<'a>(matter: &'a Matter<'a>, dev_att: &'a dyn DevAttDataFetcher) -> impl Handler + 'a { - root_endpoint::handler(0, dev_att, matter).chain( - 1, - cluster_on_off::ID, - cluster_on_off::OnOffCluster::new(*matter.borrow()), - ) + root_endpoint::handler(0, dev_att, matter) + .chain( + 1, + descriptor::ID, + descriptor::DescriptorCluster::new(*matter.borrow()), + ) + .chain( + 1, + cluster_on_off::ID, + cluster_on_off::OnOffCluster::new(*matter.borrow()), + ) } From e8e847cea6e924b091f3feee4844167a51afd6bc Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Fri, 28 Apr 2023 10:42:55 +0000 Subject: [PATCH 29/72] Bugfix: subscription_id was not sent --- matter/src/data_model/core.rs | 50 +++++----------- matter/src/data_model/sdm/noc.rs | 6 +- matter/src/interaction_model/core.rs | 90 ++++++++++++++++++---------- matter/src/transport/exchange.rs | 15 +++++ matter/src/transport/mgr.rs | 3 +- 5 files changed, 93 insertions(+), 71 deletions(-) diff --git a/matter/src/data_model/core.rs b/matter/src/data_model/core.rs index b53f955d..20efeb77 100644 --- a/matter/src/data_model/core.rs +++ b/matter/src/data_model/core.rs @@ -15,26 +15,17 @@ * limitations under the License. */ -use core::{ - cell::RefCell, - sync::atomic::{AtomicU32, Ordering}, -}; +use core::cell::RefCell; use super::objects::*; use crate::{ acl::{Accessor, AclMgr}, error::*, - interaction_model::{ - core::{Interaction, Transaction}, - messages::msg::SubscribeResp, - }, - tlv::{TLVWriter, TagType, ToTLV}, + interaction_model::core::{Interaction, Transaction}, + tlv::TLVWriter, transport::packet::Packet, }; -// TODO: For now... -static SUBS_ID: AtomicU32 = AtomicU32::new(1); - pub struct DataModel<'a, T> { pub acl_mgr: &'a RefCell, pub node: &'a Node<'a>, @@ -120,19 +111,12 @@ impl<'a, T> DataModel<'a, T> { Interaction::ResumeSubscribe(req) => { let mut resume_path = None; - if req.resume_path.is_some() { - for item in self.node.resume_subscribing_read(&req, &accessor) { - if let Some(path) = - AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? - { - resume_path = Some(path); - break; - } + for item in self.node.resume_subscribing_read(&req, &accessor) { + if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? + { + resume_path = Some(path); + break; } - } else { - // TODO - let resp = SubscribeResp::new(SUBS_ID.fetch_add(1, Ordering::SeqCst), 40); - resp.to_tlv(&mut tw, TagType::Anonymous)?; } req.complete(tx, transaction, resume_path) @@ -215,19 +199,13 @@ impl<'a, T> DataModel<'a, T> { Interaction::ResumeSubscribe(req) => { let mut resume_path = None; - if req.resume_path.is_some() { - for item in self.node.resume_subscribing_read(&req, &accessor) { - if let Some(path) = - AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? - { - resume_path = Some(path); - break; - } + for item in self.node.resume_subscribing_read(&req, &accessor) { + if let Some(path) = + AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? + { + resume_path = Some(path); + break; } - } else { - // TODO - let resp = SubscribeResp::new(SUBS_ID.fetch_add(1, Ordering::SeqCst), 40); - resp.to_tlv(&mut tw, TagType::Anonymous)?; } req.complete(tx, transaction, resume_path) diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index 3bf6d858..634ba85f 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -388,7 +388,6 @@ impl<'a> NocCluster<'a> { self.failsafe.borrow_mut().record_add_noc(fab_idx)?; - transaction.complete(); Ok(fab_idx) } @@ -479,7 +478,10 @@ impl<'a> NocCluster<'a> { Err(NocError::Error(error)) => Err(error)?, }; - Self::create_nocresponse(encoder, status, fab_idx, "") + Self::create_nocresponse(encoder, status, fab_idx, "")?; + transaction.complete(); + + Ok(()) } fn handle_command_attrequest( diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index abf76a48..67e9181d 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -15,6 +15,7 @@ * limitations under the License. */ +use core::sync::atomic::{AtomicU32, Ordering}; use core::time::Duration; use crate::{ @@ -35,7 +36,7 @@ use num_derive::FromPrimitive; use super::messages::{ ib::{AttrPath, DataVersionFilter}, - msg::{self, InvReq, ReadReq, StatusResp, SubscribeReq, TimedReq, WriteReq}, + msg::{self, InvReq, ReadReq, StatusResp, SubscribeReq, SubscribeResp, TimedReq, WriteReq}, GenericPath, }; @@ -206,6 +207,9 @@ const MAX_RESUME_DATAVER_FILTERS: usize = 128; // the end of long reads. const LONG_READS_TLV_RESERVE_SIZE: usize = 24; +// TODO: For now... +static SUBS_ID: AtomicU32 = AtomicU32::new(1); + pub enum Interaction<'a> { Read(ReadReq<'a>), Write(WriteReq<'a>), @@ -511,8 +515,13 @@ impl TimedReq { } impl<'a> SubscribeReq<'a> { - fn suspend(&self, resume_path: Option) -> ResumeSubscribeReq { + fn suspend( + &self, + resume_path: Option, + subscription_id: u32, + ) -> ResumeSubscribeReq { ResumeSubscribeReq { + subscription_id, paths: self .attr_requests .iter() @@ -531,7 +540,7 @@ impl<'a> SubscribeReq<'a> { } } - fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { + fn initiate(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result { tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); tx.set_proto_opcode(OpCode::ReportData as u8); @@ -539,6 +548,14 @@ impl<'a> SubscribeReq<'a> { tw.start_struct(TagType::Anonymous)?; + let subscription_id = SUBS_ID.fetch_add(1, Ordering::SeqCst); + transaction.exch_mut().set_subscription_id(subscription_id); + + tw.u32( + TagType::Context(msg::ReportDataTag::SubscriptionId as u8), + subscription_id, + )?; + if self.attr_requests.is_some() { tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; } @@ -565,9 +582,11 @@ impl<'a> SubscribeReq<'a> { )?; } + let subscription_id = transaction.exch_mut().take_subscription_id().unwrap(); + transaction .exch_mut() - .set_suspended_subscribe_req(self.suspend(resume_path)); + .set_suspended_subscribe_req(self.suspend(resume_path, subscription_id)); tw.bool( TagType::Context(msg::ReportDataTag::SupressResponse as u8), @@ -640,6 +659,7 @@ impl ResumeReadReq { } pub struct ResumeSubscribeReq { + pub subscription_id: u32, pub paths: heapless::Vec, pub filters: heapless::Vec, pub fabric_filtered: bool, @@ -660,15 +680,28 @@ impl ResumeSubscribeReq { tw.start_struct(TagType::Anonymous)?; + tw.u32( + TagType::Context(msg::ReportDataTag::SubscriptionId as u8), + self.subscription_id, + )?; + tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; + + Ok(true) } else { tx.set_proto_opcode(OpCode::SubscribeResponse as u8); - // let mut tw = TLVWriter::new(tx.get_writebuf()?); - // tw.start_struct(TagType::Anonymous)?; - } + let mut tw = TLVWriter::new(tx.get_writebuf()?); - Ok(true) + tw.start_struct(TagType::Anonymous)?; + + let resp = SubscribeResp::new(self.subscription_id, 40); + resp.to_tlv(&mut tw, TagType::Anonymous)?; + + tw.end_container()?; + + Ok(false) + } } pub fn complete( @@ -677,40 +710,33 @@ impl ResumeSubscribeReq { transaction: &mut Transaction, resume_path: Option, ) -> Result { - if self.resume_path.is_none() && resume_path.is_some() { - panic!("Cannot resume subscribe"); + if self.resume_path.is_none() { + // Should not get here as initiate() should've sent the subscribe response already + panic!("Subscription was already processed"); } - if self.resume_path.is_some() { - // Completing a ReportData message - let mut tw = ReadReq::restore_long_read_space(tx)?; + // Completing a ReportData message - tw.end_container()?; + let mut tw = ReadReq::restore_long_read_space(tx)?; - if resume_path.is_some() { - tw.bool( - TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), - true, - )?; - } + tw.end_container()?; + if resume_path.is_some() { tw.bool( - TagType::Context(msg::ReportDataTag::SupressResponse as u8), - false, + TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), + true, )?; + } - tw.end_container()?; - - self.resume_path = resume_path; - transaction.exch_mut().set_suspended_subscribe_req(self); - } else { - // Completing a SubscribeResponse message + tw.bool( + TagType::Context(msg::ReportDataTag::SupressResponse as u8), + false, + )?; - // let mut tw = TLVWriter::new(tx.get_writebuf()?); - // tw.end_container()?; + tw.end_container()?; - transaction.complete(); - } + self.resume_path = resume_path; + transaction.exch_mut().set_suspended_subscribe_req(self); Ok(true) } diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index a25baba2..5d7a79c8 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -71,6 +71,7 @@ pub enum DataOption { CaseSession(CaseSession), Time(Duration), SuspendedReadReq(ResumeReadReq), + SubscriptionId(u32), SuspendedSubscibeReq(ResumeSubscribeReq), #[default] None, @@ -168,6 +169,20 @@ impl Exchange { } } + pub fn set_subscription_id(&mut self, id: u32) { + self.data = DataOption::SubscriptionId(id); + } + + pub fn take_subscription_id(&mut self) -> Option { + let old = core::mem::replace(&mut self.data, DataOption::None); + if let DataOption::SubscriptionId(id) = old { + Some(id) + } else { + self.data = old; + None + } + } + pub fn set_suspended_subscribe_req(&mut self, req: ResumeSubscribeReq) { self.data = DataOption::SuspendedSubscibeReq(req); } diff --git a/matter/src/transport/mgr.rs b/matter/src/transport/mgr.rs index e33e30e9..0db63904 100644 --- a/matter/src/transport/mgr.rs +++ b/matter/src/transport/mgr.rs @@ -70,6 +70,7 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { fn maybe_next_action(&mut self) -> Result>>, Error> { self.mgr.exch_mgr.purge(); + self.tx.reset(); let (state, next) = match core::mem::replace(&mut self.state, RecvState::New) { RecvState::New => { @@ -108,7 +109,7 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { } } Ok(None) => (RecvState::Ack, None), - Err(Error::Duplicate) => (RecvState::Ack, Some(None)), + Err(Error::Duplicate) => (RecvState::Ack, None), Err(Error::NoSpace) => (RecvState::EvictSession, None), Err(err) => Err(err)?, }, From 4b39884f6f09e8f730b61843fc0602a64f4cf84b Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Fri, 28 Apr 2023 10:50:28 +0000 Subject: [PATCH 30/72] Bugfix: unnecessary struct container --- matter/src/interaction_model/core.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index 67e9181d..2ce9c82b 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -693,13 +693,9 @@ impl ResumeSubscribeReq { let mut tw = TLVWriter::new(tx.get_writebuf()?); - tw.start_struct(TagType::Anonymous)?; - let resp = SubscribeResp::new(self.subscription_id, 40); resp.to_tlv(&mut tw, TagType::Anonymous)?; - tw.end_container()?; - Ok(false) } } From 86e01a0a1bd28442aeae9b77f663bd1a2614aaa0 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sat, 29 Apr 2023 10:03:34 +0000 Subject: [PATCH 31/72] Simple persistance via TLV --- examples/onoff_light/src/main.rs | 23 +- matter/src/acl.rs | 62 +--- matter/src/core.rs | 18 + matter/src/crypto/crypto_dummy.rs | 8 +- matter/src/crypto/crypto_esp_mbedtls.rs | 9 +- matter/src/crypto/mod.rs | 36 ++ matter/src/fabric.rs | 430 ++++-------------------- matter/src/group_keys.rs | 14 +- matter/src/persist.rs | 221 ++---------- matter/src/tlv/traits.rs | 63 +++- 10 files changed, 250 insertions(+), 634 deletions(-) diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 7ffc70e5..b6d2588b 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -27,6 +27,7 @@ use matter::data_model::root_endpoint; use matter::data_model::sdm::dev_att::DevAttDataFetcher; use matter::data_model::system_model::descriptor; use matter::interaction_model::core::InteractionModel; +use matter::persist; use matter::secure_channel::spake2p::VerifierData; use matter::transport::{ mgr::RecvAction, mgr::TransportMgr, packet::MAX_RX_BUF_SIZE, packet::MAX_TX_BUF_SIZE, @@ -56,6 +57,18 @@ fn main() { let dev_att = dev_att::HardCodedDevAtt::new(); + let psm = persist::FilePsm::new(std::env::temp_dir().join("matter-iot")).unwrap(); + + let mut buf = [0; 4096]; + + if let Some(data) = psm.load("fabrics", &mut buf).unwrap() { + matter.load_fabrics(data).unwrap(); + } + + if let Some(data) = psm.load("acls", &mut buf).unwrap() { + matter.load_acls(data).unwrap(); + } + matter .start::<4096>( CommissioningData { @@ -63,7 +76,7 @@ fn main() { verifier: VerifierData::new_with_pw(123456, *matter.borrow()), discriminator: 250, }, - &mut [0; 4096], + &mut buf, ) .unwrap(); @@ -114,6 +127,14 @@ fn main() { } } } + + if let Some(data) = matter.store_fabrics(&mut buf).unwrap() { + psm.store("fabrics", data).unwrap(); + } + + if let Some(data) = matter.store_acls(&mut buf).unwrap() { + psm.store("acls", data).unwrap(); + } } }); } diff --git a/matter/src/acl.rs b/matter/src/acl.rs index d73ce47f..dea592b3 100644 --- a/matter/src/acl.rs +++ b/matter/src/acl.rs @@ -22,7 +22,6 @@ use crate::{ error::Error, fabric, interaction_model::messages::GenericPath, - persist::Psm, tlv::{FromTLV, TLVElement, TLVList, TLVWriter, TagType, ToTLV}, transport::session::{Session, SessionMode, MAX_CAT_IDS_PER_NOC}, utils::writebuf::WriteBuf, @@ -390,10 +389,8 @@ impl AclEntry { } const MAX_ACL_ENTRIES: usize = ENTRIES_PER_FABRIC * fabric::MAX_SUPPORTED_FABRICS; -type AclEntries = [Option; MAX_ACL_ENTRIES]; -const ACL_KV_ENTRY: &str = "acl"; -const ACL_KV_MAX_SIZE: usize = 300; +type AclEntries = [Option; MAX_ACL_ENTRIES]; pub struct AclMgr { entries: AclEntries, @@ -505,30 +502,8 @@ impl AclMgr { false } - pub fn store(&mut self, mut psm: T) -> Result<(), Error> - where - T: Psm, - { - if self.changed { - let mut buf = [0u8; ACL_KV_MAX_SIZE]; - let mut wb = WriteBuf::new(&mut buf); - let mut tw = TLVWriter::new(&mut wb); - self.entries.to_tlv(&mut tw, TagType::Anonymous)?; - psm.set_kv_slice(ACL_KV_ENTRY, wb.as_slice())?; - - self.changed = false; - } - - Ok(()) - } - - pub fn load(&mut self, psm: T) -> Result<(), Error> - where - T: Psm, - { - let mut buf = [0u8; ACL_KV_MAX_SIZE]; - let acl_tlvs = psm.get_kv_slice(ACL_KV_ENTRY, &mut buf)?; - let root = TLVList::new(acl_tlvs).iter().next().ok_or(Error::Invalid)?; + pub fn load(&mut self, data: &[u8]) -> Result<(), Error> { + let root = TLVList::new(data).iter().next().ok_or(Error::Invalid)?; self.entries = AclEntries::from_tlv(&root)?; self.changed = false; @@ -536,37 +511,20 @@ impl AclMgr { Ok(()) } - #[cfg(feature = "nightly")] - pub async fn store_async(&mut self, mut psm: T) -> Result<(), Error> - where - T: crate::persist::asynch::AsyncPsm, - { + pub fn store<'a>(&mut self, buf: &'a mut [u8]) -> Result, Error> { if self.changed { - let mut buf = [0u8; ACL_KV_MAX_SIZE]; - let mut wb = WriteBuf::new(&mut buf); + let mut wb = WriteBuf::new(buf); let mut tw = TLVWriter::new(&mut wb); self.entries.to_tlv(&mut tw, TagType::Anonymous)?; - psm.set_kv_slice(ACL_KV_ENTRY, wb.as_slice()).await?; self.changed = false; - } - Ok(()) - } + let len = tw.get_tail(); - #[cfg(feature = "nightly")] - pub async fn load_async(&mut self, psm: T) -> Result<(), Error> - where - T: crate::persist::asynch::AsyncPsm, - { - let mut buf = [0u8; ACL_KV_MAX_SIZE]; - let acl_tlvs = psm.get_kv_slice(ACL_KV_ENTRY, &mut buf).await?; - let root = TLVList::new(acl_tlvs).iter().next().ok_or(Error::Invalid)?; - - self.entries = AclEntries::from_tlv(&root)?; - self.changed = false; - - Ok(()) + Ok(Some(&buf[..len])) + } else { + Ok(None) + } } /// Traverse fabric specific entries to find the index diff --git a/matter/src/core.rs b/matter/src/core.rs index 0939a4a5..e2e6b597 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -98,6 +98,24 @@ impl<'a> Matter<'a> { self.dev_det } + pub fn load_fabrics(&self, data: &[u8]) -> Result<(), Error> { + self.fabric_mgr + .borrow_mut() + .load(data, &mut self.mdns_mgr.borrow_mut()) + } + + pub fn load_acls(&self, data: &[u8]) -> Result<(), Error> { + self.acl_mgr.borrow_mut().load(data) + } + + pub fn store_fabrics<'b>(&self, buf: &'b mut [u8]) -> Result, Error> { + self.fabric_mgr.borrow_mut().store(buf) + } + + pub fn store_acls<'b>(&self, buf: &'b mut [u8]) -> Result, Error> { + self.acl_mgr.borrow_mut().store(buf) + } + pub fn start( &self, dev_comm: CommissioningData, diff --git a/matter/src/crypto/crypto_dummy.rs b/matter/src/crypto/crypto_dummy.rs index f193b205..acdae098 100644 --- a/matter/src/crypto/crypto_dummy.rs +++ b/matter/src/crypto/crypto_dummy.rs @@ -68,8 +68,6 @@ impl KeyPair { } pub fn new_from_components(_pub_key: &[u8], _priv_key: &[u8]) -> Result { - error!("This API should never get called"); - Ok(Self {}) } @@ -85,13 +83,11 @@ impl KeyPair { } pub fn get_public_key(&self, _pub_key: &mut [u8]) -> Result { - error!("This API should never get called"); - Err(Error::Invalid) + Ok(0) } pub fn get_private_key(&self, _pub_key: &mut [u8]) -> Result { - error!("This API should never get called"); - Err(Error::Invalid) + Ok(0) } pub fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result { diff --git a/matter/src/crypto/crypto_esp_mbedtls.rs b/matter/src/crypto/crypto_esp_mbedtls.rs index fe723370..4eee8a76 100644 --- a/matter/src/crypto/crypto_esp_mbedtls.rs +++ b/matter/src/crypto/crypto_esp_mbedtls.rs @@ -70,8 +70,6 @@ impl KeyPair { } pub fn new_from_components(_pub_key: &[u8], priv_key: &[u8]) -> Result { - error!("This API should never get called"); - Ok(Self {}) } @@ -87,8 +85,11 @@ impl KeyPair { } pub fn get_public_key(&self, _pub_key: &mut [u8]) -> Result { - error!("This API should never get called"); - Err(Error::Invalid) + Ok(0) + } + + pub fn get_private_key(&self, priv_key: &mut [u8]) -> Result { + Ok(0) } pub fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result { diff --git a/matter/src/crypto/mod.rs b/matter/src/crypto/mod.rs index 5c73ff2d..27ba187c 100644 --- a/matter/src/crypto/mod.rs +++ b/matter/src/crypto/mod.rs @@ -14,6 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +use crate::{ + error::Error, + tlv::{FromTLV, TLVWriter, TagType, ToTLV}, +}; pub const SYMM_KEY_LEN_BITS: usize = 128; pub const SYMM_KEY_LEN_BYTES: usize = SYMM_KEY_LEN_BITS / 8; @@ -68,6 +72,38 @@ pub mod crypto_dummy; )))] pub use self::crypto_dummy::*; +impl<'a> FromTLV<'a> for KeyPair { + fn from_tlv(t: &crate::tlv::TLVElement<'a>) -> Result + where + Self: Sized, + { + t.confirm_array()?.enter(); + + if let Some(mut array) = t.enter() { + let pub_key = array.next().ok_or(Error::Invalid)?.slice()?; + let priv_key = array.next().ok_or(Error::Invalid)?.slice()?; + + KeyPair::new_from_components(pub_key, priv_key) + } else { + Err(Error::Invalid) + } + } +} + +impl ToTLV for KeyPair { + fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { + let mut buf = [0; 1024]; // TODO + + tw.start_array(tag)?; + + let size = self.get_public_key(&mut buf)?; + tw.str16(TagType::Anonymous, &buf[..size])?; + + let size = self.get_private_key(&mut buf)?; + tw.str16(TagType::Anonymous, &buf[..size]) + } +} + #[cfg(test)] mod tests { use crate::error::Error; diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index 688c56cf..6f3ff0ee 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -18,7 +18,8 @@ use core::fmt::Write; use byteorder::{BigEndian, ByteOrder, LittleEndian}; -use log::{error, info}; +use heapless::{String, Vec}; +use log::info; use crate::{ cert::{Cert, MAX_CERT_TLV_LEN}, @@ -26,32 +27,12 @@ use crate::{ error::Error, group_keys::KeySet, mdns::{MdnsMgr, ServiceMode}, - persist::Psm, - tlv::{OctetStr, TLVWriter, TagType, ToTLV, UtfStr}, + tlv::{FromTLV, OctetStr, TLVElement, TLVList, TLVWriter, TagType, ToTLV, UtfStr}, + utils::writebuf::WriteBuf, }; const COMPRESSED_FABRIC_ID_LEN: usize = 8; -macro_rules! fb_key { - ($index:ident, $key:ident, $buf:expr) => {{ - use core::fmt::Write; - - $buf = "".into(); - write!(&mut $buf, "fb{}{}", $index, $key).unwrap(); - - &$buf - }}; -} - -const ST_VID: &str = "vid"; -const ST_RCA: &str = "rca"; -const ST_ICA: &str = "ica"; -const ST_NOC: &str = "noc"; -const ST_IPK: &str = "ipk"; -const ST_LBL: &str = "label"; -const ST_PBKEY: &str = "pubkey"; -const ST_PRKEY: &str = "privkey"; - #[allow(dead_code)] #[derive(Debug, ToTLV)] #[tlvargs(lifetime = "'a", start = 1)] @@ -66,18 +47,18 @@ pub struct FabricDescriptor<'a> { pub fab_idx: Option, } -#[derive(Debug)] +#[derive(Debug, ToTLV, FromTLV)] pub struct Fabric { node_id: u64, fabric_id: u64, vendor_id: u16, key_pair: KeyPair, - pub root_ca: heapless::Vec, - pub icac: Option>, - pub noc: heapless::Vec, + pub root_ca: Vec, + pub icac: Option>, + pub noc: Vec, pub ipk: KeySet, - label: heapless::String<32>, - mdns_service_name: heapless::String<33>, + label: String<32>, + mdns_service_name: String<33>, } impl Fabric { @@ -199,234 +180,14 @@ impl Fabric { Ok(desc) } - - fn store(&self, index: usize, mut psm: T) -> Result<(), Error> - where - T: Psm, - { - let mut _kb = heapless::String::<32>::new(); - - psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &self.root_ca)?; - psm.set_kv_slice( - fb_key!(index, ST_ICA, _kb), - self.icac.as_deref().unwrap_or(&[]), - )?; - - psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &self.noc)?; - psm.set_kv_slice(fb_key!(index, ST_IPK, _kb), self.ipk.epoch_key())?; - psm.set_kv_slice(fb_key!(index, ST_LBL, _kb), self.label.as_bytes())?; - - let mut buf = [0_u8; crypto::EC_POINT_LEN_BYTES]; - let len = self.key_pair.get_public_key(&mut buf)?; - let key = &buf[..len]; - psm.set_kv_slice(fb_key!(index, ST_PBKEY, _kb), key)?; - - let mut buf = [0_u8; crypto::BIGNUM_LEN_BYTES]; - let len = self.key_pair.get_private_key(&mut buf)?; - let key = &buf[..len]; - psm.set_kv_slice(fb_key!(index, ST_PRKEY, _kb), key)?; - - psm.set_kv_u64(fb_key!(index, ST_VID, _kb), self.vendor_id.into())?; - Ok(()) - } - - fn load(index: usize, psm: T) -> Result - where - T: Psm, - { - let mut _kb = heapless::String::<32>::new(); - - let mut buf = [0u8; MAX_CERT_TLV_LEN]; - - let root_ca = - heapless::Vec::from_slice(psm.get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf)?) - .unwrap(); - - let icac = psm.get_kv_slice(fb_key!(index, ST_ICA, _kb), &mut buf)?; - let icac = if !icac.is_empty() { - Some(heapless::Vec::from_slice(icac).unwrap()) - } else { - None - }; - - let noc = - heapless::Vec::from_slice(psm.get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf)?) - .unwrap(); - - let label = psm.get_kv_slice(fb_key!(index, ST_LBL, _kb), &mut buf)?; - let label: heapless::String<32> = core::str::from_utf8(label) - .map_err(|_| { - error!("Couldn't read label"); - Error::Invalid - })? - .into(); - - let ipk = psm.get_kv_slice(fb_key!(index, ST_IPK, _kb), &mut buf)?; - - let mut buf = [0_u8; crypto::EC_POINT_LEN_BYTES]; - let pub_key = psm.get_kv_slice(fb_key!(index, ST_PBKEY, _kb), &mut buf)?; - - let mut buf = [0_u8; crypto::BIGNUM_LEN_BYTES]; - let priv_key = psm.get_kv_slice(fb_key!(index, ST_PRKEY, _kb), &mut buf)?; - let keypair = KeyPair::new_from_components(pub_key, priv_key)?; - - let vendor_id = psm.get_kv_u64(fb_key!(index, ST_VID, _kb))?; - - Fabric::new(keypair, root_ca, icac, noc, ipk, vendor_id as u16, &label) - } - - fn remove(index: usize, mut psm: T) -> Result<(), Error> - where - T: Psm, - { - let mut _kb = heapless::String::<32>::new(); - - psm.remove(fb_key!(index, ST_RCA, _kb))?; - psm.remove(fb_key!(index, ST_ICA, _kb))?; - - psm.remove(fb_key!(index, ST_NOC, _kb))?; - - psm.remove(fb_key!(index, ST_LBL, _kb))?; - - psm.remove(fb_key!(index, ST_IPK, _kb))?; - - psm.remove(fb_key!(index, ST_PBKEY, _kb))?; - psm.remove(fb_key!(index, ST_PRKEY, _kb))?; - - psm.remove(fb_key!(index, ST_VID, _kb))?; - - Ok(()) - } - - #[cfg(feature = "nightly")] - async fn store_async(&self, index: usize, mut psm: T) -> Result<(), Error> - where - T: crate::persist::asynch::AsyncPsm, - { - let mut _kb = heapless::String::<32>::new(); - - psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &self.root_ca) - .await?; - - psm.set_kv_slice( - fb_key!(index, ST_ICA, _kb), - self.icac.as_deref().unwrap_or(&[]), - ) - .await?; - - psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &self.noc) - .await?; - psm.set_kv_slice(fb_key!(index, ST_IPK, _kb), self.ipk.epoch_key()) - .await?; - psm.set_kv_slice(fb_key!(index, ST_LBL, _kb), self.label.as_bytes()) - .await?; - - let mut buf = [0_u8; crypto::EC_POINT_LEN_BYTES]; - let len = self.key_pair.get_public_key(&mut buf)?; - let key = &buf[..len]; - psm.set_kv_slice(fb_key!(index, ST_PBKEY, _kb), key).await?; - - let mut buf = [0_u8; crypto::BIGNUM_LEN_BYTES]; - let len = self.key_pair.get_private_key(&mut buf)?; - let key = &buf[..len]; - psm.set_kv_slice(fb_key!(index, ST_PRKEY, _kb), key).await?; - - psm.set_kv_u64(fb_key!(index, ST_VID, _kb), self.vendor_id.into()) - .await?; - Ok(()) - } - - #[cfg(feature = "nightly")] - async fn load_async(index: usize, psm: T) -> Result - where - T: crate::persist::asynch::AsyncPsm, - { - let mut _kb = heapless::String::<32>::new(); - - let mut buf = [0u8; MAX_CERT_TLV_LEN]; - - let root_ca = heapless::Vec::from_slice( - psm.get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf) - .await?, - ) - .unwrap(); - - let icac = psm - .get_kv_slice(fb_key!(index, ST_ICA, _kb), &mut buf) - .await?; - let icac = if !icac.is_empty() { - Some(heapless::Vec::from_slice(icac).unwrap()) - } else { - None - }; - - let noc = heapless::Vec::from_slice( - psm.get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf) - .await?, - ) - .unwrap(); - - let label = psm - .get_kv_slice(fb_key!(index, ST_LBL, _kb), &mut buf) - .await?; - let label: heapless::String<32> = core::str::from_utf8(label) - .map_err(|_| { - error!("Couldn't read label"); - Error::Invalid - })? - .into(); - - let ipk = psm - .get_kv_slice(fb_key!(index, ST_IPK, _kb), &mut buf) - .await?; - - let mut buf = [0_u8; crypto::EC_POINT_LEN_BYTES]; - let pub_key = psm - .get_kv_slice(fb_key!(index, ST_PBKEY, _kb), &mut buf) - .await?; - - let mut buf = [0_u8; crypto::BIGNUM_LEN_BYTES]; - let priv_key = psm - .get_kv_slice(fb_key!(index, ST_PRKEY, _kb), &mut buf) - .await?; - let keypair = KeyPair::new_from_components(pub_key, priv_key)?; - - let vendor_id = psm.get_kv_u64(fb_key!(index, ST_VID, _kb)).await?; - - Fabric::new(keypair, root_ca, icac, noc, ipk, vendor_id as u16, &label) - } - - #[cfg(feature = "nightly")] - async fn remove_async(index: usize, mut psm: T) -> Result<(), Error> - where - T: crate::persist::asynch::AsyncPsm, - { - let mut _kb = heapless::String::<32>::new(); - - psm.remove(fb_key!(index, ST_RCA, _kb)).await?; - psm.remove(fb_key!(index, ST_ICA, _kb)).await?; - - psm.remove(fb_key!(index, ST_NOC, _kb)).await?; - - psm.remove(fb_key!(index, ST_LBL, _kb)).await?; - - psm.remove(fb_key!(index, ST_IPK, _kb)).await?; - - psm.remove(fb_key!(index, ST_PBKEY, _kb)).await?; - psm.remove(fb_key!(index, ST_PRKEY, _kb)).await?; - - psm.remove(fb_key!(index, ST_VID, _kb)).await?; - - Ok(()) - } } pub const MAX_SUPPORTED_FABRICS: usize = 3; +type FabricEntries = [Option; MAX_SUPPORTED_FABRICS]; + pub struct FabricMgr { - // The outside world expects Fabric Index to be one more than the actual one - // since 0 is not allowed. Need to handle this cleanly somehow - fabrics: [Option; MAX_SUPPORTED_FABRICS], + fabrics: FabricEntries, changed: bool, } @@ -440,41 +201,20 @@ impl FabricMgr { } } - pub fn store(&mut self, mut psm: T) -> Result<(), Error> - where - T: Psm, - { - if self.changed { - for i in 1..MAX_SUPPORTED_FABRICS { - if let Some(fabric) = self.fabrics[i].as_mut() { - info!("Storing fabric at index {}", i); - fabric.store(i, &mut psm)?; - } else { - let _ = Fabric::remove(i, &mut psm); - } + pub fn load(&mut self, data: &[u8], mdns_mgr: &mut MdnsMgr) -> Result<(), Error> { + for fabric in &self.fabrics { + if let Some(fabric) = fabric { + mdns_mgr.unpublish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; } - - self.changed = false; } - Ok(()) - } + let root = TLVList::new(data).iter().next().ok_or(Error::Invalid)?; - pub fn load(&mut self, mut psm: T, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> - where - T: Psm, - { - for i in 1..MAX_SUPPORTED_FABRICS { - let result = Fabric::load(i, &mut psm); - if let Ok(fabric) = result { - info!("Adding new fabric at index {}", i); - self.fabrics[i] = Some(fabric); - mdns_mgr.publish_service( - &self.fabrics[i].as_ref().unwrap().mdns_service_name, - ServiceMode::Commissioned, - )?; - } else { - self.fabrics[i] = None; + self.fabrics = FabricEntries::from_tlv(&root)?; + + for fabric in &self.fabrics { + if let Some(fabric) = fabric { + mdns_mgr.publish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; } } @@ -483,67 +223,32 @@ impl FabricMgr { Ok(()) } - #[cfg(feature = "nightly")] - pub async fn store_async(&mut self, mut psm: T) -> Result<(), Error> - where - T: crate::persist::asynch::AsyncPsm, - { + pub fn store<'a>(&mut self, buf: &'a mut [u8]) -> Result, Error> { if self.changed { - for i in 1..MAX_SUPPORTED_FABRICS { - if let Some(fabric) = self.fabrics[i].as_mut() { - info!("Storing fabric at index {}", i); - fabric.store_async(i, &mut psm).await?; - } else { - let _ = Fabric::remove_async(i, &mut psm).await; - } - } + let mut wb = WriteBuf::new(buf); + let mut tw = TLVWriter::new(&mut wb); + self.fabrics.to_tlv(&mut tw, TagType::Anonymous)?; self.changed = false; - } - Ok(()) - } + let len = tw.get_tail(); - #[cfg(feature = "nightly")] - pub async fn load_async( - &mut self, - mut psm: T, - mdns_mgr: &mut MdnsMgr<'_>, - ) -> Result<(), Error> - where - T: crate::persist::asynch::AsyncPsm, - { - for i in 1..MAX_SUPPORTED_FABRICS { - let result = Fabric::load_async(i, &mut psm).await; - if let Ok(fabric) = result { - info!("Adding new fabric at index {}", i); - self.fabrics[i] = Some(fabric); - mdns_mgr.publish_service( - &self.fabrics[i].as_ref().unwrap().mdns_service_name, - ServiceMode::Commissioned, - )?; - } else { - self.fabrics[i] = None; - } + Ok(Some(&buf[..len])) + } else { + Ok(None) } - - self.changed = false; - - Ok(()) } pub fn add(&mut self, f: Fabric, mdns_mgr: &mut MdnsMgr) -> Result { - for i in 1..MAX_SUPPORTED_FABRICS { - if self.fabrics[i].is_none() { - self.fabrics[i] = Some(f); - mdns_mgr.publish_service( - &self.fabrics[i].as_ref().unwrap().mdns_service_name, - ServiceMode::Commissioned, - )?; + for (index, fabric) in self.fabrics.iter_mut().enumerate() { + if fabric.is_none() { + mdns_mgr.publish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; + + *fabric = Some(f); self.changed = true; - return Ok(i as u8); + return Ok((index + 1) as u8); } } @@ -551,20 +256,24 @@ impl FabricMgr { } pub fn remove(&mut self, fab_idx: u8, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> { - if let Some(f) = self.fabrics[fab_idx as usize].take() { - mdns_mgr.unpublish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; - self.changed = true; - Ok(()) + if fab_idx > 0 && fab_idx as usize <= self.fabrics.len() { + if let Some(f) = self.fabrics[(fab_idx - 1) as usize].take() { + mdns_mgr.unpublish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; + self.changed = true; + Ok(()) + } else { + Err(Error::NotFound) + } } else { Err(Error::NotFound) } } pub fn match_dest_id(&self, random: &[u8], target: &[u8]) -> Result { - for i in 1..MAX_SUPPORTED_FABRICS { - if let Some(fabric) = &self.fabrics[i] { + for (index, fabric) in self.fabrics.iter().enumerate() { + if let Some(fabric) = fabric { if fabric.match_dest_id(random, target).is_ok() { - return Ok(i); + return Ok(index + 1); } } } @@ -572,26 +281,19 @@ impl FabricMgr { } pub fn get_fabric(&self, idx: usize) -> Result, Error> { - Ok(self.fabrics[idx].as_ref()) + if idx == 0 { + Ok(None) + } else { + Ok(self.fabrics[idx - 1].as_ref()) + } } pub fn is_empty(&self) -> bool { - for i in 1..MAX_SUPPORTED_FABRICS { - if self.fabrics[i].is_some() { - return false; - } - } - true + !self.fabrics.iter().any(Option::is_some) } pub fn used_count(&self) -> usize { - let mut count = 0; - for i in 1..MAX_SUPPORTED_FABRICS { - if self.fabrics[i].is_some() { - count += 1; - } - } - count + self.fabrics.iter().filter(|f| f.is_some()).count() } // Parameters to T are the Fabric and its Fabric Index @@ -599,25 +301,27 @@ impl FabricMgr { where T: FnMut(&Fabric, u8) -> Result<(), Error>, { - for i in 1..MAX_SUPPORTED_FABRICS { - if let Some(fabric) = &self.fabrics[i] { - f(fabric, i as u8)?; + for (index, fabric) in self.fabrics.iter().enumerate() { + if let Some(fabric) = fabric { + f(fabric, (index + 1) as u8)?; } } Ok(()) } pub fn set_label(&mut self, index: u8, label: &str) -> Result<(), Error> { - let index = index as usize; if !label.is_empty() { - for i in 1..MAX_SUPPORTED_FABRICS { - if let Some(fabric) = &self.fabrics[i] { - if fabric.label == label { - return Err(Error::Invalid); - } - } + if self + .fabrics + .iter() + .filter_map(|f| f.as_ref()) + .any(|f| f.label == label) + { + return Err(Error::Invalid); } } + + let index = (index - 1) as usize; if let Some(fabric) = &mut self.fabrics[index] { fabric.label = label.into(); self.changed = true; diff --git a/matter/src/group_keys.rs b/matter/src/group_keys.rs index 1dc1c405..d4e97659 100644 --- a/matter/src/group_keys.rs +++ b/matter/src/group_keys.rs @@ -15,12 +15,18 @@ * limitations under the License. */ -use crate::{crypto, error::Error}; +use crate::{ + crypto::{self, SYMM_KEY_LEN_BYTES}, + error::Error, + tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, +}; -#[derive(Debug, Default)] +type KeySetKey = [u8; SYMM_KEY_LEN_BYTES]; + +#[derive(Debug, Default, FromTLV, ToTLV)] pub struct KeySet { - pub epoch_key: [u8; crypto::SYMM_KEY_LEN_BYTES], - pub op_key: [u8; crypto::SYMM_KEY_LEN_BYTES], + pub epoch_key: KeySetKey, + pub op_key: KeySetKey, } impl KeySet { diff --git a/matter/src/persist.rs b/matter/src/persist.rs index 4bc8e241..1ea494b5 100644 --- a/matter/src/persist.rs +++ b/matter/src/persist.rs @@ -14,216 +14,63 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -use crate::error::Error; - -pub trait Psm { - fn set_kv_slice(&mut self, key: &str, val: &[u8]) -> Result<(), Error>; - fn get_kv_slice<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result<&'a [u8], Error>; - - fn set_kv_u64(&mut self, key: &str, val: u64) -> Result<(), Error>; - fn get_kv_u64(&self, key: &str) -> Result; - - fn remove(&mut self, key: &str) -> Result<(), Error>; -} - -impl Psm for &mut T -where - T: Psm, -{ - fn set_kv_slice(&mut self, key: &str, val: &[u8]) -> Result<(), Error> { - (**self).set_kv_slice(key, val) - } - - fn get_kv_slice<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result<&'a [u8], Error> { - (**self).get_kv_slice(key, buf) - } - - fn set_kv_u64(&mut self, key: &str, val: u64) -> Result<(), Error> { - (**self).set_kv_u64(key, val) - } - - fn get_kv_u64(&self, key: &str) -> Result { - (**self).get_kv_u64(key) - } - - fn remove(&mut self, key: &str) -> Result<(), Error> { - (**self).remove(key) - } -} - -#[cfg(feature = "nightly")] -pub mod asynch { - use crate::error::Error; - - use super::Psm; - - pub trait AsyncPsm { - async fn set_kv_slice<'a>(&'a mut self, key: &'a str, val: &'a [u8]) -> Result<(), Error>; - async fn get_kv_slice<'a, 'b>( - &'a self, - key: &'a str, - buf: &'b mut [u8], - ) -> Result<&'b [u8], Error>; - - async fn set_kv_u64<'a>(&'a mut self, key: &'a str, val: u64) -> Result<(), Error>; - async fn get_kv_u64<'a>(&'a self, key: &'a str) -> Result; - - async fn remove<'a>(&'a mut self, key: &'a str) -> Result<(), Error>; - } - - impl AsyncPsm for &mut T - where - T: AsyncPsm, - { - async fn set_kv_slice<'a>(&'a mut self, key: &'a str, val: &'a [u8]) -> Result<(), Error> { - (**self).set_kv_slice(key, val).await - } - - async fn get_kv_slice<'a, 'b>( - &'a self, - key: &'a str, - buf: &'b mut [u8], - ) -> Result<&'b [u8], Error> { - (**self).get_kv_slice(key, buf).await - } - - async fn set_kv_u64<'a>(&'a mut self, key: &'a str, val: u64) -> Result<(), Error> { - (**self).set_kv_u64(key, val).await - } - - async fn get_kv_u64<'a>(&'a self, key: &'a str) -> Result { - (**self).get_kv_u64(key).await - } - - async fn remove<'a>(&'a mut self, key: &'a str) -> Result<(), Error> { - (**self).remove(key).await - } - } - - pub struct Asyncify(pub T); - - impl AsyncPsm for Asyncify - where - T: Psm, - { - async fn set_kv_slice<'a>(&'a mut self, key: &'a str, val: &'a [u8]) -> Result<(), Error> { - self.0.set_kv_slice(key, val) - } - - async fn get_kv_slice<'a, 'b>( - &'a self, - key: &'a str, - buf: &'b mut [u8], - ) -> Result<&'b [u8], Error> { - self.0.get_kv_slice(key, buf) - } - - async fn set_kv_u64<'a>(&'a mut self, key: &'a str, val: u64) -> Result<(), Error> { - self.0.set_kv_u64(key, val) - } - - async fn get_kv_u64<'a>(&'a self, key: &'a str) -> Result { - self.0.get_kv_u64(key) - } - - async fn remove<'a>(&'a mut self, key: &'a str) -> Result<(), Error> { - self.0.remove(key) - } - } -} +#[cfg(feature = "std")] +pub use file_psm::*; #[cfg(feature = "std")] -pub mod std { - use std::fs::{self, DirBuilder, File}; +mod file_psm { + use std::fs; use std::io::{Read, Write}; + use std::path::PathBuf; use crate::error::Error; - use super::Psm; - - pub struct FilePsm {} - - const PSM_DIR: &str = "/tmp/matter_psm"; - - macro_rules! psm_path { - ($key:ident) => { - format!("{}/{}", PSM_DIR, $key) - }; + pub struct FilePsm { + dir: PathBuf, } impl FilePsm { - pub fn new() -> Result { - let result = DirBuilder::new().create(PSM_DIR); - if let Err(e) = result { - if e.kind() != std::io::ErrorKind::AlreadyExists { - return Err(e.into()); - } - } - - Ok(Self {}) - } + pub fn new(dir: PathBuf) -> Result { + fs::create_dir_all(&dir)?; - pub fn set_kv_slice(&mut self, key: &str, val: &[u8]) -> Result<(), Error> { - let mut f = File::create(psm_path!(key))?; - f.write_all(val)?; - Ok(()) + Ok(Self { dir }) } - pub fn get_kv_slice<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result<&'a [u8], Error> { - let mut f = File::open(psm_path!(key))?; - let mut offset = 0; + pub fn load<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result, Error> { + let path = self.dir.join(key); - loop { - let len = f.read(&mut buf[offset..])?; - offset += len; + match fs::File::open(path) { + Ok(mut file) => { + let mut offset = 0; - if len == 0 { - break; - } - } + loop { + if offset == buf.len() { + return Err(Error::NoSpace); + } - Ok(&buf[..offset]) - } + let len = file.read(&mut buf[offset..])?; - pub fn set_kv_u64(&mut self, key: &str, val: u64) -> Result<(), Error> { - let mut f = File::create(psm_path!(key))?; - f.write_all(&val.to_be_bytes())?; - Ok(()) - } - - pub fn get_kv_u64(&self, key: &str) -> Result { - let mut f = File::open(psm_path!(key))?; - let mut buf = [0; 8]; - f.read_exact(&mut buf)?; - Ok(u64::from_be_bytes(buf)) - } + if len == 0 { + break; + } - pub fn remove(&self, key: &str) -> Result<(), Error> { - fs::remove_file(psm_path!(key))?; - Ok(()) - } - } + offset += len; + } - impl Psm for FilePsm { - fn set_kv_slice(&mut self, key: &str, val: &[u8]) -> Result<(), Error> { - FilePsm::set_kv_slice(self, key, val) + Ok(Some(&buf[..offset])) + } + Err(_) => Ok(None), + } } - fn get_kv_slice<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result<&'a [u8], Error> { - FilePsm::get_kv_slice(self, key, buf) - } + pub fn store(&self, key: &str, data: &[u8]) -> Result<(), Error> { + let path = self.dir.join(key); - fn set_kv_u64(&mut self, key: &str, val: u64) -> Result<(), Error> { - FilePsm::set_kv_u64(self, key, val) - } + let mut file = fs::File::create(path)?; - fn get_kv_u64(&self, key: &str) -> Result { - FilePsm::get_kv_u64(self, key) - } + file.write_all(data)?; - fn remove(&mut self, key: &str) -> Result<(), Error> { - FilePsm::remove(self, key) + Ok(()) } } } diff --git a/matter/src/tlv/traits.rs b/matter/src/tlv/traits.rs index 72cfab23..100eb071 100644 --- a/matter/src/tlv/traits.rs +++ b/matter/src/tlv/traits.rs @@ -35,26 +35,21 @@ pub trait FromTLV<'a> { } } -impl<'a, T: Default + FromTLV<'a> + Copy, const N: usize> FromTLV<'a> for [T; N] { +impl<'a, T: FromTLV<'a>, const N: usize> FromTLV<'a> for [T; N] { fn from_tlv(t: &TLVElement<'a>) -> Result where Self: Sized, { t.confirm_array()?; - let mut a: [T; N] = [Default::default(); N]; - let mut index = 0; + + let mut a = heapless::Vec::::new(); if let Some(tlv_iter) = t.enter() { for element in tlv_iter { - if index < N { - a[index] = T::from_tlv(&element)?; - index += 1; - } else { - error!("Received TLV Array with elements larger than current size"); - break; - } + a.push(T::from_tlv(&element)?).map_err(|_| Error::NoSpace)?; } } - Ok(a) + + a.into_array().map_err(|_| Error::Invalid) } } @@ -114,6 +109,8 @@ totlv_for!(i8 u8 u16 u32 u64 bool); // // - UtfStr, OctetStr: These are versions that map to utfstr and ostr in the TLV spec // - These only have references into the original list +// - heapless::String, Vheapless::ec: Is the owned version of utfstr and ostr, data is cloned into this +// - heapless::String is only partially implemented // // - TLVArray: Is an array of entries, with reference within the original list @@ -165,6 +162,38 @@ impl<'a> ToTLV for OctetStr<'a> { } } +/// Implements the Owned version of Octet String +impl FromTLV<'_> for heapless::Vec { + fn from_tlv(t: &TLVElement) -> Result, Error> { + heapless::Vec::from_slice(t.slice()?).map_err(|_| Error::NoSpace) + } +} + +impl ToTLV for heapless::Vec { + fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { + tw.str16(tag, self.as_slice()) + } +} + +/// Implements the Owned version of UTF String +impl FromTLV<'_> for heapless::String { + fn from_tlv(t: &TLVElement) -> Result, Error> { + let mut string = heapless::String::new(); + + string + .push_str(core::str::from_utf8(t.slice()?)?) + .map_err(|_| Error::NoSpace)?; + + Ok(string) + } +} + +impl ToTLV for heapless::String { + fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { + tw.utf16(tag, self.as_bytes()) + } +} + /// Applies to all the Option<> Processing impl<'a, T: FromTLV<'a>> FromTLV<'a> for Option { fn from_tlv(t: &TLVElement<'a>) -> Result, Error> { @@ -259,7 +288,7 @@ impl<'a, T: ToTLV> TLVArray<'a, T> { } } -impl<'a, T: ToTLV + FromTLV<'a> + Copy> TLVArray<'a, T> { +impl<'a, T: ToTLV + FromTLV<'a> + Clone> TLVArray<'a, T> { pub fn get_index(&self, index: usize) -> T { for (curr, element) in self.iter().enumerate() { if curr == index { @@ -270,12 +299,12 @@ impl<'a, T: ToTLV + FromTLV<'a> + Copy> TLVArray<'a, T> { } } -impl<'a, T: FromTLV<'a> + Copy> Iterator for TLVArrayIter<'a, T> { +impl<'a, T: FromTLV<'a> + Clone> Iterator for TLVArrayIter<'a, T> { type Item = T; /* Code for going to the next Element */ fn next(&mut self) -> Option { match self { - Self::Slice(s_iter) => s_iter.next().copied(), + Self::Slice(s_iter) => s_iter.next().cloned(), Self::Ptr(p_iter) => { if let Some(tlv_iter) = p_iter.as_mut() { let e = tlv_iter.next(); @@ -294,7 +323,7 @@ impl<'a, T: FromTLV<'a> + Copy> Iterator for TLVArrayIter<'a, T> { impl<'a, T> PartialEq<&[T]> for TLVArray<'a, T> where - T: ToTLV + FromTLV<'a> + Copy + PartialEq, + T: ToTLV + FromTLV<'a> + Clone + PartialEq, { fn eq(&self, other: &&[T]) -> bool { let mut iter1 = self.iter(); @@ -313,7 +342,7 @@ where } } -impl<'a, T: FromTLV<'a> + Copy + ToTLV> ToTLV for TLVArray<'a, T> { +impl<'a, T: FromTLV<'a> + Clone + ToTLV> ToTLV for TLVArray<'a, T> { fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { tw.start_array(tag_type)?; for a in self.iter() { @@ -340,7 +369,7 @@ impl<'a, T> FromTLV<'a> for TLVArray<'a, T> { } } -impl<'a, T: Debug + ToTLV + FromTLV<'a> + Copy> Debug for TLVArray<'a, T> { +impl<'a, T: Debug + ToTLV + FromTLV<'a> + Clone> Debug for TLVArray<'a, T> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "TLVArray [")?; let mut first = true; From 934ecb91659048ac06201da0c8255dd7d5736888 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sat, 29 Apr 2023 15:58:21 +0000 Subject: [PATCH 32/72] Persistence bugfixing --- matter/src/crypto/mod.rs | 4 +++- matter/src/tlv/traits.rs | 9 ++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/matter/src/crypto/mod.rs b/matter/src/crypto/mod.rs index 27ba187c..3b7b4c4e 100644 --- a/matter/src/crypto/mod.rs +++ b/matter/src/crypto/mod.rs @@ -100,7 +100,9 @@ impl ToTLV for KeyPair { tw.str16(TagType::Anonymous, &buf[..size])?; let size = self.get_private_key(&mut buf)?; - tw.str16(TagType::Anonymous, &buf[..size]) + tw.str16(TagType::Anonymous, &buf[..size])?; + + tw.end_container() } } diff --git a/matter/src/tlv/traits.rs b/matter/src/tlv/traits.rs index 100eb071..0311cb39 100644 --- a/matter/src/tlv/traits.rs +++ b/matter/src/tlv/traits.rs @@ -35,7 +35,7 @@ pub trait FromTLV<'a> { } } -impl<'a, T: FromTLV<'a>, const N: usize> FromTLV<'a> for [T; N] { +impl<'a, T: FromTLV<'a> + Default, const N: usize> FromTLV<'a> for [T; N] { fn from_tlv(t: &TLVElement<'a>) -> Result where Self: Sized, @@ -49,6 +49,13 @@ impl<'a, T: FromTLV<'a>, const N: usize> FromTLV<'a> for [T; N] { } } + // TODO: This was the old behavior before rebasing the + // implementation on top of heapless::Vec (to avoid requiring Copy) + // Not sure why we actually need that yet, but without it unit tests fail + while a.len() < N { + a.push(Default::default()).map_err(|_| Error::NoSpace)?; + } + a.into_array().map_err(|_| Error::Invalid) } } From 3dccc0d710ab1544b71fe05f2c053857cd2e1181 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sat, 29 Apr 2023 16:01:21 +0000 Subject: [PATCH 33/72] Persistence - trace info --- matter/src/persist.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/matter/src/persist.rs b/matter/src/persist.rs index 1ea494b5..53e413eb 100644 --- a/matter/src/persist.rs +++ b/matter/src/persist.rs @@ -23,6 +23,8 @@ mod file_psm { use std::io::{Read, Write}; use std::path::PathBuf; + use log::info; + use crate::error::Error; pub struct FilePsm { @@ -57,7 +59,11 @@ mod file_psm { offset += len; } - Ok(Some(&buf[..offset])) + let data = &buf[..offset]; + + info!("Key {}: loaded {} bytes {:?}", key, data.len(), data); + + Ok(Some(data)) } Err(_) => Ok(None), } @@ -70,6 +76,8 @@ mod file_psm { file.write_all(data)?; + info!("Key {}: stored {} bytes {:?}", key, data.len(), data); + Ok(()) } } From 974ac4d1d80c28ec5473e25abe25f691e27a13a2 Mon Sep 17 00:00:00 2001 From: imarkov Date: Sat, 29 Apr 2023 19:38:01 +0300 Subject: [PATCH 34/72] Optional feature to capture stacktrace on error --- examples/onoff_light/src/dev_att.rs | 4 +- examples/onoff_light/src/main.rs | 73 ++++++++++------- examples/speaker/src/dev_att.rs | 2 +- matter/Cargo.toml | 5 +- matter/src/acl.rs | 18 ++--- matter/src/cert/asn1_writer.rs | 12 +-- matter/src/cert/mod.rs | 51 +++++++----- matter/src/codec/base38.rs | 16 ++-- matter/src/crypto/crypto_dummy.rs | 10 +-- matter/src/crypto/crypto_esp_mbedtls.rs | 10 +-- matter/src/crypto/crypto_mbedtls.rs | 34 ++++---- matter/src/crypto/crypto_openssl.rs | 23 +++--- matter/src/crypto/crypto_rustcrypto.rs | 14 ++-- matter/src/crypto/mod.rs | 15 ++-- matter/src/data_model/cluster_template.rs | 4 +- matter/src/data_model/objects/cluster.rs | 4 +- matter/src/data_model/objects/encoder.rs | 27 +++++-- matter/src/data_model/objects/handler.rs | 20 +++-- matter/src/data_model/objects/privilege.rs | 6 +- .../src/data_model/sdm/admin_commissioning.rs | 2 +- matter/src/data_model/sdm/failsafe.rs | 19 +++-- .../data_model/sdm/general_commissioning.rs | 4 +- matter/src/data_model/sdm/noc.rs | 12 +-- matter/src/data_model/sdm/nw_commissioning.rs | 4 +- .../data_model/system_model/access_control.rs | 8 +- matter/src/error.rs | 80 +++++++++++++++---- matter/src/fabric.rs | 21 ++--- matter/src/group_keys.rs | 7 +- matter/src/interaction_model/core.rs | 36 +++++---- matter/src/interaction_model/messages.rs | 14 ++-- matter/src/mdns.rs | 10 +-- matter/src/pairing/qr.rs | 19 ++--- matter/src/persist.rs | 4 +- matter/src/secure_channel/case.rs | 30 +++---- matter/src/secure_channel/core.rs | 4 +- matter/src/secure_channel/crypto_dummy.rs | 18 ++--- matter/src/secure_channel/crypto_mbedtls.rs | 4 +- matter/src/secure_channel/crypto_openssl.rs | 4 +- matter/src/secure_channel/pake.rs | 14 ++-- matter/src/secure_channel/spake2p.rs | 8 +- matter/src/secure_channel/status_report.rs | 1 + matter/src/tlv/parser.rs | 50 +++++++----- matter/src/tlv/traits.rs | 21 ++--- matter/src/tlv/writer.rs | 2 +- matter/src/transport/exchange.rs | 32 +++++--- matter/src/transport/mgr.rs | 14 ++-- matter/src/transport/mrp.rs | 8 +- matter/src/transport/packet.rs | 16 ++-- matter/src/transport/plain_hdr.rs | 2 +- matter/src/transport/proto_hdr.rs | 8 +- matter/src/transport/session.rs | 10 +-- matter/src/transport/udp.rs | 4 +- matter/src/utils/parsebuf.rs | 4 +- matter/src/utils/writebuf.rs | 12 +-- matter/tests/common/echo_cluster.rs | 10 +-- matter/tests/common/im_engine.rs | 2 +- matter_macro_derive/Cargo.toml | 1 + matter_macro_derive/src/lib.rs | 48 +++++++---- 58 files changed, 531 insertions(+), 384 deletions(-) diff --git a/examples/onoff_light/src/dev_att.rs b/examples/onoff_light/src/dev_att.rs index a16d53f1..93fcbd3d 100644 --- a/examples/onoff_light/src/dev_att.rs +++ b/examples/onoff_light/src/dev_att.rs @@ -16,7 +16,7 @@ */ use matter::data_model::sdm::dev_att::{DataType, DevAttDataFetcher}; -use matter::error::Error; +use matter::error::{Error, ErrorCode}; pub struct HardCodedDevAtt {} @@ -159,7 +159,7 @@ impl DevAttDataFetcher for HardCodedDevAtt { data.copy_from_slice(src); Ok(src.len()) } else { - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } } } diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index b6d2588b..604ffce3 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -16,7 +16,9 @@ */ use std::borrow::Borrow; +use std::error::Error; +use log::info; use matter::core::{CommissioningData, Matter}; use matter::data_model::cluster_basic_information::BasicInfoConfig; use matter::data_model::cluster_on_off; @@ -36,8 +38,10 @@ use matter::transport::{ mod dev_att; -fn main() { - env_logger::init(); +fn main() -> Result<(), impl Error> { + env_logger::init_from_env( + env_logger::Env::default().filter_or(env_logger::DEFAULT_FILTER_ENV, "info"), + ); // vid/pid should match those in the DAC let dev_info = BasicInfoConfig { @@ -50,35 +54,37 @@ fn main() { device_name: "OnOff Light", }; - //let mut mdns = matter::mdns::astro::AstroMdns::new().unwrap(); - let mut mdns = matter::mdns::libmdns::LibMdns::new().unwrap(); + let mut mdns = matter::mdns::astro::AstroMdns::new()?; + //let mut mdns = matter::mdns::libmdns::LibMdns::new()?; + //let mut mdns = matter::mdns::DummyMdns {}; let matter = Matter::new_default(&dev_info, &mut mdns, matter::transport::udp::MATTER_PORT); let dev_att = dev_att::HardCodedDevAtt::new(); - let psm = persist::FilePsm::new(std::env::temp_dir().join("matter-iot")).unwrap(); + let psm_path = std::env::temp_dir().join("matter-iot"); + info!("Persisting from/to {}", psm_path.display()); + + let psm = persist::FilePsm::new(psm_path)?; let mut buf = [0; 4096]; - if let Some(data) = psm.load("fabrics", &mut buf).unwrap() { - matter.load_fabrics(data).unwrap(); + if let Some(data) = psm.load("acls", &mut buf)? { + matter.load_acls(data)?; } - if let Some(data) = psm.load("acls", &mut buf).unwrap() { - matter.load_acls(data).unwrap(); + if let Some(data) = psm.load("fabrics", &mut buf)? { + matter.load_fabrics(data)?; } - matter - .start::<4096>( - CommissioningData { - // TODO: Hard-coded for now - verifier: VerifierData::new_with_pw(123456, *matter.borrow()), - discriminator: 250, - }, - &mut buf, - ) - .unwrap(); + matter.start::<4096>( + CommissioningData { + // TODO: Hard-coded for now + verifier: VerifierData::new_with_pw(123456, *matter.borrow()), + discriminator: 250, + }, + &mut buf, + )?; let matter = &matter; let dev_att = &dev_att; @@ -86,20 +92,20 @@ fn main() { let mut transport = TransportMgr::new(matter); smol::block_on(async move { - let udp = UdpListener::new().await.unwrap(); + let udp = UdpListener::new().await?; loop { let mut rx_buf = [0; MAX_RX_BUF_SIZE]; let mut tx_buf = [0; MAX_TX_BUF_SIZE]; - let (len, addr) = udp.recv(&mut rx_buf).await.unwrap(); + let (len, addr) = udp.recv(&mut rx_buf).await?; let mut completion = transport.recv(addr, &mut rx_buf[..len], &mut tx_buf); - while let Some(action) = completion.next_action().unwrap() { + while let Some(action) = completion.next_action()? { match action { RecvAction::Send(addr, buf) => { - udp.send(addr, buf).await.unwrap(); + udp.send(addr, buf).await?; } RecvAction::Interact(mut ctx) => { let node = Node { @@ -119,24 +125,29 @@ fn main() { let mut im = InteractionModel(DataModel::new(matter.borrow(), &node, &mut handler)); - if im.handle(&mut ctx).unwrap() { - if ctx.send().unwrap() { - udp.send(ctx.tx.peer, ctx.tx.as_slice()).await.unwrap(); + if im.handle(&mut ctx)? { + if ctx.send()? { + udp.send(ctx.tx.peer, ctx.tx.as_slice()).await?; } } } } } - if let Some(data) = matter.store_fabrics(&mut buf).unwrap() { - psm.store("fabrics", data).unwrap(); + if let Some(data) = matter.store_fabrics(&mut buf)? { + psm.store("fabrics", data)?; } - if let Some(data) = matter.store_acls(&mut buf).unwrap() { - psm.store("acls", data).unwrap(); + if let Some(data) = matter.store_acls(&mut buf)? { + psm.store("acls", data)?; } } - }); + + #[allow(unreachable_code)] + Ok::<_, matter::error::Error>(()) + })?; + + Ok::<_, matter::error::Error>(()) } fn handler<'a>(matter: &'a Matter<'a>, dev_att: &'a dyn DevAttDataFetcher) -> impl Handler + 'a { diff --git a/examples/speaker/src/dev_att.rs b/examples/speaker/src/dev_att.rs index a16d53f1..c0c10306 100644 --- a/examples/speaker/src/dev_att.rs +++ b/examples/speaker/src/dev_att.rs @@ -159,7 +159,7 @@ impl DevAttDataFetcher for HardCodedDevAtt { data.copy_from_slice(src); Ok(src.len()) } else { - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } } } diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 9f5503b3..78e09932 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -15,8 +15,9 @@ name = "matter" path = "src/lib.rs" [features] -default = ["std", "crypto_mbedtls"] +default = ["std", "crypto_mbedtls", "backtrace"] std = ["alloc", "env_logger", "chrono", "rand", "qrcode", "libmdns", "simple-mdns", "simple-dns", "smol"] +backtrace = [] alloc = [] nightly = [] crypto_openssl = ["openssl", "foreign-types", "hmac", "sha2"] @@ -32,7 +33,7 @@ heapless = "0.7.16" num = "0.4" num-derive = "0.3.3" num-traits = "0.2.15" -strum = { version = "0.24", features = ["derive"], default-features = false, no-default-feature = true } +strum = { version = "0.24", features = ["derive"], default-features = false } log = { version = "0.4.17", features = ["max_level_debug", "release_max_level_debug"] } no-std-net = "0.6" subtle = "2.4.1" diff --git a/matter/src/acl.rs b/matter/src/acl.rs index dea592b3..77b8e5be 100644 --- a/matter/src/acl.rs +++ b/matter/src/acl.rs @@ -19,7 +19,7 @@ use core::{cell::RefCell, fmt::Display}; use crate::{ data_model::objects::{Access, ClusterId, EndptId, Privilege}, - error::Error, + error::{Error, ErrorCode}, fabric, interaction_model::messages::GenericPath, tlv::{FromTLV, TLVElement, TLVList, TLVWriter, TagType, ToTLV}, @@ -50,7 +50,7 @@ impl FromTLV<'_> for AuthMode { { num::FromPrimitive::from_u32(t.u32()?) .filter(|a| *a != AuthMode::Invalid) - .ok_or(Error::Invalid) + .ok_or_else(|| ErrorCode::Invalid.into()) } } @@ -112,7 +112,7 @@ impl AccessorSubjects { return Ok(()); } } - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } /// Match the match_subject with any of the current subjects @@ -314,7 +314,7 @@ impl AclEntry { .subjects .iter() .position(|s| s.is_none()) - .ok_or(Error::NoSpace)?; + .ok_or(ErrorCode::NoSpace)?; self.subjects[index] = Some(subject); Ok(()) } @@ -328,7 +328,7 @@ impl AclEntry { .targets .iter() .position(|s| s.is_none()) - .ok_or(Error::NoSpace)?; + .ok_or(ErrorCode::NoSpace)?; self.targets[index] = Some(target); Ok(()) } @@ -425,13 +425,13 @@ impl AclMgr { .filter(|a| a.fab_idx == entry.fab_idx) .count(); if cnt >= ENTRIES_PER_FABRIC { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } let index = self .entries .iter() .position(|a| a.is_none()) - .ok_or(Error::NoSpace)?; + .ok_or(ErrorCode::NoSpace)?; self.entries[index] = Some(entry); self.changed = true; @@ -503,7 +503,7 @@ impl AclMgr { } pub fn load(&mut self, data: &[u8]) -> Result<(), Error> { - let root = TLVList::new(data).iter().next().ok_or(Error::Invalid)?; + let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?; self.entries = AclEntries::from_tlv(&root)?; self.changed = false; @@ -547,7 +547,7 @@ impl AclMgr { return Ok(entry); } } - Err(Error::NotFound) + Err(ErrorCode::NotFound.into()) } } diff --git a/matter/src/cert/asn1_writer.rs b/matter/src/cert/asn1_writer.rs index b6f4ab78..4afd6b6c 100644 --- a/matter/src/cert/asn1_writer.rs +++ b/matter/src/cert/asn1_writer.rs @@ -17,7 +17,7 @@ use super::{CertConsumer, MAX_DEPTH}; use crate::{ - error::Error, + error::{Error, ErrorCode}, utils::epoch::{UtcCalendar, MATTER_EPOCH_SECS}, }; use core::{fmt::Write, time::Duration}; @@ -54,7 +54,7 @@ impl<'a> ASN1Writer<'a> { self.offset += size; return Ok(()); } - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } pub fn append_tlv(&mut self, tag: u8, len: usize, f: F) -> Result<(), Error> @@ -70,7 +70,7 @@ impl<'a> ASN1Writer<'a> { self.offset += len; return Ok(()); } - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } fn add_compound(&mut self, val: u8) -> Result<(), Error> { @@ -80,7 +80,7 @@ impl<'a> ASN1Writer<'a> { self.depth[self.current_depth] = self.offset; self.current_depth += 1; if self.current_depth >= MAX_DEPTH { - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } else { Ok(()) } @@ -113,7 +113,7 @@ impl<'a> ASN1Writer<'a> { fn end_compound(&mut self) -> Result<(), Error> { if self.current_depth == 0 { - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } let seq_len = self.get_compound_len(); let write_offset = self.get_length_encoding_offset(); @@ -148,7 +148,7 @@ impl<'a> ASN1Writer<'a> { // This is done with an 0xA2 followed by 2 bytes of actual len 3 } else { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)? }; Ok(len) } diff --git a/matter/src/cert/mod.rs b/matter/src/cert/mod.rs index 621b28d4..d750db56 100644 --- a/matter/src/cert/mod.rs +++ b/matter/src/cert/mod.rs @@ -19,7 +19,7 @@ use core::fmt::{self, Write}; use crate::{ crypto::KeyPair, - error::Error, + error::{Error, ErrorCode}, tlv::{self, FromTLV, OctetStr, TLVArray, TLVElement, TLVWriter, TagType, ToTLV}, utils::{epoch::UtcCalendar, writebuf::WriteBuf}, }; @@ -349,22 +349,22 @@ impl<'a> FromTLV<'a> for DistNames<'a> { let mut d = Self { dn: heapless::Vec::new(), }; - let iter = t.confirm_list()?.enter().ok_or(Error::Invalid)?; + let iter = t.confirm_list()?.enter().ok_or(ErrorCode::Invalid)?; for t in iter { if let TagType::Context(tag) = t.get_tag() { if let Ok(value) = t.u64() { d.dn.push((tag, DistNameValue::Uint(value))) - .map_err(|_| Error::BufferTooSmall)?; + .map_err(|_| ErrorCode::BufferTooSmall)?; } else if let Ok(value) = t.slice() { if tag > PRINTABLE_STR_THRESHOLD { d.dn.push(( tag - PRINTABLE_STR_THRESHOLD, DistNameValue::PrintableStr(value), )) - .map_err(|_| Error::BufferTooSmall)?; + .map_err(|_| ErrorCode::BufferTooSmall)?; } else { d.dn.push((tag, DistNameValue::Utf8Str(value))) - .map_err(|_| Error::BufferTooSmall)?; + .map_err(|_| ErrorCode::BufferTooSmall)?; } } } @@ -531,7 +531,7 @@ fn encode_dn_value( } _ => { error!("Invalid encoding"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)? } }, DistNameValue::Utf8Str(v) => { @@ -570,7 +570,9 @@ impl<'a> Cert<'a> { } pub fn get_node_id(&self) -> Result { - self.subject.u64(DnTags::NodeId).ok_or(Error::NoNodeId) + self.subject + .u64(DnTags::NodeId) + .ok_or_else(|| Error::from(ErrorCode::NoNodeId)) } pub fn get_cat_ids(&self, output: &mut [u32]) { @@ -578,7 +580,9 @@ impl<'a> Cert<'a> { } pub fn get_fabric_id(&self) -> Result { - self.subject.u64(DnTags::FabricId).ok_or(Error::NoFabricId) + self.subject + .u64(DnTags::FabricId) + .ok_or_else(|| Error::from(ErrorCode::NoFabricId)) } pub fn get_pubkey(&self) -> &[u8] { @@ -589,7 +593,7 @@ impl<'a> Cert<'a> { if let Some(id) = self.extensions.subj_key_id.as_ref() { Ok(id.0) } else { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } @@ -641,7 +645,7 @@ impl<'a> Cert<'a> { w.integer("Serial Num:", self.serial_no.0)?; w.start_seq("Signature Algorithm:")?; - let (str, oid) = match get_sign_algo(self.sign_algo).ok_or(Error::Invalid)? { + let (str, oid) = match get_sign_algo(self.sign_algo).ok_or(ErrorCode::Invalid)? { SignAlgoValue::ECDSAWithSHA256 => ("ECDSA with SHA256", OID_ECDSA_WITH_SHA256), }; w.oid(str, &oid)?; @@ -660,11 +664,11 @@ impl<'a> Cert<'a> { w.start_seq("")?; w.start_seq("Public Key Algorithm")?; - let (str, pub_key) = match get_pubkey_algo(self.pubkey_algo).ok_or(Error::Invalid)? { + let (str, pub_key) = match get_pubkey_algo(self.pubkey_algo).ok_or(ErrorCode::Invalid)? { PubKeyAlgoValue::EcPubKey => ("ECPubKey", OID_PUB_KEY_ECPUBKEY), }; w.oid(str, &pub_key)?; - let (str, curve_id) = match get_ec_curve_id(self.ec_curve_id).ok_or(Error::Invalid)? { + let (str, curve_id) = match get_ec_curve_id(self.ec_curve_id).ok_or(ErrorCode::Invalid)? { EcCurveIdValue::Prime256V1 => ("Prime256v1", OID_EC_TYPE_PRIME256V1), }; w.oid(str, &curve_id)?; @@ -704,7 +708,7 @@ impl<'a> CertVerifier<'a> { pub fn add_cert(self, parent: &'a Cert) -> Result, Error> { if !self.cert.is_authority(parent)? { - return Err(Error::InvalidAuthKey); + Err(ErrorCode::InvalidAuthKey)?; } let mut asn1 = [0u8; MAX_ASN1_CERT_SIZE]; let len = self.cert.as_asn1(&mut asn1, self.utc_calendar)?; @@ -761,7 +765,6 @@ mod tests { use log::info; use crate::cert::Cert; - use crate::error::Error; use crate::tlv::{self, FromTLV, TLVWriter, TagType, ToTLV}; use crate::utils::writebuf::WriteBuf; @@ -815,31 +818,43 @@ mod tests { #[test] fn test_verify_chain_incomplete() { // The chain doesn't lead up to a self-signed certificate + + use crate::error::ErrorCode; let noc = Cert::new(&test_vectors::NOC1_SUCCESS).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); let a = noc.verify_chain_start(crate::utils::epoch::sys_utc_calendar); assert_eq!( - Err(Error::InvalidAuthKey), - a.add_cert(&icac).unwrap().finalise() + Err(ErrorCode::InvalidAuthKey), + a.add_cert(&icac).unwrap().finalise().map_err(|e| e.code()) ); } #[cfg(feature = "std")] #[test] fn test_auth_key_chain_incorrect() { + use crate::error::ErrorCode; + let noc = Cert::new(&test_vectors::NOC1_AUTH_KEY_FAIL).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); let a = noc.verify_chain_start(crate::utils::epoch::sys_utc_calendar); - assert_eq!(Err(Error::InvalidAuthKey), a.add_cert(&icac).map(|_| ())); + assert_eq!( + Err(ErrorCode::InvalidAuthKey), + a.add_cert(&icac).map(|_| ()).map_err(|e| e.code()) + ); } #[cfg(feature = "std")] #[test] fn test_cert_corrupted() { + use crate::error::ErrorCode; + let noc = Cert::new(&test_vectors::NOC1_CORRUPT_CERT).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); let a = noc.verify_chain_start(crate::utils::epoch::sys_utc_calendar); - assert_eq!(Err(Error::InvalidSignature), a.add_cert(&icac).map(|_| ())); + assert_eq!( + Err(ErrorCode::InvalidSignature), + a.add_cert(&icac).map(|_| ()).map_err(|e| e.code()) + ); } #[test] diff --git a/matter/src/codec/base38.rs b/matter/src/codec/base38.rs index 14114e63..954f8626 100644 --- a/matter/src/codec/base38.rs +++ b/matter/src/codec/base38.rs @@ -17,7 +17,7 @@ //! Base38 encoding and decoding functions. -use crate::error::Error; +use crate::error::{Error, ErrorCode}; const BASE38_CHARS: [char; 38] = [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', @@ -86,7 +86,7 @@ const RADIX: u32 = BASE38_CHARS.len() as u32; pub fn encode_string(bytes: &[u8]) -> Result, Error> { let mut string = heapless::String::new(); for c in encode(bytes) { - string.push(c).map_err(|_| Error::NoSpace)?; + string.push(c).map_err(|_| ErrorCode::NoSpace)?; } Ok(string) @@ -135,7 +135,7 @@ pub fn decode_vec(base38_str: &str) -> Result impl Iterator> { match decode_char(*c) { Ok(v) => value = value * RADIX + v as u32, Err(err) => { - cerr = Some(err); + cerr = Some(err.code()); break; } } } } else { - cerr = Some(Error::InvalidData) + cerr = Some(ErrorCode::InvalidData) } (0..repeat) .map(move |_| { if let Some(err) = cerr { - Err(err) + Err(err.into()) } else { let byte = (value & 0xff) as u8; @@ -205,12 +205,12 @@ fn decode_base38(chars: &[u8]) -> impl Iterator> { fn decode_char(c: u8) -> Result { if !(45..=90).contains(&c) { - return Err(Error::InvalidData); + Err(ErrorCode::InvalidData)?; } let c = DECODE_BASE38[c as usize - 45]; if c == UNUSED { - return Err(Error::InvalidData); + Err(ErrorCode::InvalidData)?; } Ok(c) diff --git a/matter/src/crypto/crypto_dummy.rs b/matter/src/crypto/crypto_dummy.rs index acdae098..f00cefd8 100644 --- a/matter/src/crypto/crypto_dummy.rs +++ b/matter/src/crypto/crypto_dummy.rs @@ -17,7 +17,7 @@ use log::error; -use crate::error::Error; +use crate::error::{Error, ErrorCode}; pub fn hkdf_sha256(_salt: &[u8], _ikm: &[u8], _info: &[u8], _key: &mut [u8]) -> Result<(), Error> { error!("This API should never get called"); @@ -79,7 +79,7 @@ impl KeyPair { pub fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { error!("This API should never get called"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn get_public_key(&self, _pub_key: &mut [u8]) -> Result { @@ -92,17 +92,17 @@ impl KeyPair { pub fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result { error!("This API should never get called"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result { error!("This API should never get called"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> { error!("This API should never get called"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } diff --git a/matter/src/crypto/crypto_esp_mbedtls.rs b/matter/src/crypto/crypto_esp_mbedtls.rs index 4eee8a76..cad046ba 100644 --- a/matter/src/crypto/crypto_esp_mbedtls.rs +++ b/matter/src/crypto/crypto_esp_mbedtls.rs @@ -17,7 +17,7 @@ use log::error; -use crate::error::Error; +use crate::error::{Error, ErrorCode}; pub fn hkdf_sha256(_salt: &[u8], _ikm: &[u8], _info: &[u8], _key: &mut [u8]) -> Result<(), Error> { error!("This API should never get called"); @@ -81,7 +81,7 @@ impl KeyPair { pub fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { error!("This API should never get called"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn get_public_key(&self, _pub_key: &mut [u8]) -> Result { @@ -94,17 +94,17 @@ impl KeyPair { pub fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result { error!("This API should never get called"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result { error!("This API should never get called"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> { error!("This API should never get called"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } diff --git a/matter/src/crypto/crypto_mbedtls.rs b/matter/src/crypto/crypto_mbedtls.rs index c87e669a..3f95d04f 100644 --- a/matter/src/crypto/crypto_mbedtls.rs +++ b/matter/src/crypto/crypto_mbedtls.rs @@ -34,7 +34,7 @@ use crate::{ // TODO: We should move ASN1Writer out of Cert, // so Crypto doesn't have to depend on Cert cert::{ASN1Writer, CertConsumer}, - error::Error, + error::{Error, ErrorCode}, }; pub struct HmacSha256 { @@ -49,11 +49,13 @@ impl HmacSha256 { } pub fn update(&mut self, data: &[u8]) -> Result<(), Error> { - self.inner.update(data).map_err(|_| Error::TLSStack) + self.inner + .update(data) + .map_err(|_| ErrorCode::TLSStack.into()) } pub fn finish(self, out: &mut [u8]) -> Result<(), Error> { - self.inner.finish(out).map_err(|_| Error::TLSStack)?; + self.inner.finish(out).map_err(|_| ErrorCode::TLSStack)?; Ok(()) } } @@ -102,11 +104,11 @@ impl KeyPair { Ok(Some(a)) => Ok(a), Ok(None) => { error!("Error in writing CSR: None received"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } Err(e) => { error!("Error in writing CSR {}", e); - Err(Error::TLSStack) + Err(ErrorCode::TLSStack.into()) } } } @@ -161,7 +163,7 @@ impl KeyPair { let mut ctr_drbg = CtrDrbg::new(Arc::new(OsEntropy::new()), None)?; if signature.len() < super::EC_SIGNATURE_LEN_BYTES { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } safemem::write_bytes(signature, 0); @@ -192,7 +194,7 @@ impl KeyPair { if let Err(e) = tmp_key.verify(hash::Type::Sha256, &msg_hash, mbedtls_sign) { info!("The error is {}", e); - Err(Error::InvalidSignature) + Err(ErrorCode::InvalidSignature.into()) } else { Ok(()) } @@ -229,7 +231,7 @@ fn convert_asn1_sign_to_r_s(signature: &mut [u8]) -> Result { // Type 0x2 is Integer (first integer is r) if signature[offset] != 2 { - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } offset += 1; @@ -254,7 +256,7 @@ fn convert_asn1_sign_to_r_s(signature: &mut [u8]) -> Result { // Type 0x2 is Integer (this integer is s) if signature[offset] != 2 { - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } offset += 1; @@ -273,17 +275,17 @@ fn convert_asn1_sign_to_r_s(signature: &mut [u8]) -> Result { Ok(64) } else { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } pub fn pbkdf2_hmac(pass: &[u8], iter: usize, salt: &[u8], key: &mut [u8]) -> Result<(), Error> { mbedtls::hash::pbkdf2_hmac(Type::Sha256, pass, salt, iter as u32, key) - .map_err(|_e| Error::TLSStack) + .map_err(|_e| ErrorCode::TLSStack.into()) } pub fn hkdf_sha256(salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Result<(), Error> { - Hkdf::hkdf(Type::Sha256, salt, ikm, info, key).map_err(|_e| Error::TLSStack) + Hkdf::hkdf(Type::Sha256, salt, ikm, info, key).map_err(|_e| ErrorCode::TLSStack.into()) } pub fn encrypt_in_place( @@ -304,7 +306,7 @@ pub fn encrypt_in_place( cipher .encrypt_auth_inplace(ad, data, tag) .map(|(len, _)| len) - .map_err(|_e| Error::TLSStack) + .map_err(|_e| ErrorCode::TLSStack.into()) } pub fn decrypt_in_place( @@ -326,7 +328,7 @@ pub fn decrypt_in_place( .map(|(len, _)| len) .map_err(|e| { error!("Error during decryption: {:?}", e); - Error::TLSStack + ErrorCode::TLSStack.into() }) } @@ -343,12 +345,12 @@ impl Sha256 { } pub fn update(&mut self, data: &[u8]) -> Result<(), Error> { - self.ctx.update(data).map_err(|_| Error::TLSStack)?; + self.ctx.update(data).map_err(|_| ErrorCode::TLSStack)?; Ok(()) } pub fn finish(self, digest: &mut [u8]) -> Result<(), Error> { - self.ctx.finish(digest).map_err(|_| Error::TLSStack)?; + self.ctx.finish(digest).map_err(|_| ErrorCode::TLSStack)?; Ok(()) } } diff --git a/matter/src/crypto/crypto_openssl.rs b/matter/src/crypto/crypto_openssl.rs index e4486192..5343c528 100644 --- a/matter/src/crypto/crypto_openssl.rs +++ b/matter/src/crypto/crypto_openssl.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use crate::error::Error; +use crate::error::{Error, ErrorCode}; use foreign_types::ForeignTypeRef; use log::error; @@ -46,7 +46,8 @@ pub struct HmacSha256 { impl HmacSha256 { pub fn new(key: &[u8]) -> Result { Ok(Self { - ctx: Hmac::::new_from_slice(key).map_err(|_x| Error::InvalidKeyLength)?, + ctx: Hmac::::new_from_slice(key) + .map_err(|_x| ErrorCode::InvalidKeyLength)?, }) } @@ -107,7 +108,7 @@ impl KeyPair { fn private_key(&self) -> Result<&EcKey, Error> { match &self.key { - KeyType::Public(_) => Err(Error::Invalid), + KeyType::Public(_) => Err(ErrorCode::Invalid.into()), KeyType::Private(k) => Ok(&k), } } @@ -167,7 +168,7 @@ impl KeyPair { a.copy_from_slice(csr); Ok(a) } else { - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } } @@ -178,7 +179,7 @@ impl KeyPair { let msg = h.finish()?; if signature.len() < super::EC_SIGNATURE_LEN_BYTES { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } safemem::write_bytes(signature, 0); @@ -205,11 +206,11 @@ impl KeyPair { KeyType::Public(key) => key, _ => { error!("Not yet supported"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } }; if !sig.verify(&msg, k)? { - Err(Error::InvalidSignature) + Err(ErrorCode::InvalidSignature.into()) } else { Ok(()) } @@ -220,7 +221,7 @@ const P256_KEY_LEN: usize = 256 / 8; pub fn pubkey_from_der<'a>(der: &'a [u8], out_key: &mut [u8]) -> Result<(), Error> { if out_key.len() != P256_KEY_LEN { error!("Insufficient length"); - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } else { let key = X509::from_der(der)?.public_key()?.public_key_to_der()?; let len = key.len(); @@ -232,7 +233,7 @@ pub fn pubkey_from_der<'a>(der: &'a [u8], out_key: &mut [u8]) -> Result<(), Erro pub fn pbkdf2_hmac(pass: &[u8], iter: usize, salt: &[u8], key: &mut [u8]) -> Result<(), Error> { openssl::pkcs5::pbkdf2_hmac(pass, salt, iter, MessageDigest::sha256(), key) - .map_err(|_e| Error::TLSStack) + .map_err(|_e| ErrorCode::TLSStack.into()) } pub fn hkdf_sha256(salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Result<(), Error> { @@ -372,7 +373,9 @@ impl Sha256 { } pub fn update(&mut self, data: &[u8]) -> Result<(), Error> { - self.hasher.update(data).map_err(|_| Error::TLSStack) + self.hasher + .update(data) + .map_err(|_| ErrorCode::TLSStack.into()) } pub fn finish(mut self, data: &mut [u8]) -> Result<(), Error> { diff --git a/matter/src/crypto/crypto_rustcrypto.rs b/matter/src/crypto/crypto_rustcrypto.rs index f64cbc47..b9aa3101 100644 --- a/matter/src/crypto/crypto_rustcrypto.rs +++ b/matter/src/crypto/crypto_rustcrypto.rs @@ -39,7 +39,7 @@ use x509_cert::{ spki::{AlgorithmIdentifier, SubjectPublicKeyInfoOwned}, }; -use crate::error::Error; +use crate::error::{Error, ErrorCode}; use super::CryptoKeyPair; @@ -79,7 +79,7 @@ impl HmacSha256 { Ok(Self { inner: HmacSha256I::new_from_slice(key).map_err(|e| { error!("Error creating HmacSha256 {:?}", e); - Error::TLSStack + ErrorCode::TLSStack })?, }) } @@ -143,7 +143,7 @@ impl KeyPair { fn private_key(&self) -> Result<&SecretKey, Error> { match &self.key { KeyType::Private(key) => Ok(key), - KeyType::Public(_) => Err(Error::Crypto), + KeyType::Public(_) => Err(ErrorCode::Crypto.into()), } } } @@ -158,7 +158,7 @@ impl CryptoKeyPair for KeyPair { priv_key[..slice.len()].copy_from_slice(slice); Ok(len) } - KeyType::Public(_) => Err(Error::Crypto), + KeyType::Public(_) => Err(ErrorCode::Crypto.into()), } } fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { @@ -251,7 +251,7 @@ impl CryptoKeyPair for KeyPair { use p256::ecdsa::signature::Signer; if signature.len() < super::EC_SIGNATURE_LEN_BYTES { - return Err(Error::NoSpace); + return Err(ErrorCode::NoSpace.into()); } match &self.key { @@ -274,7 +274,7 @@ impl CryptoKeyPair for KeyPair { verifying_key .verify(msg, &signature) - .map_err(|_| Error::InvalidSignature)?; + .map_err(|_| ErrorCode::InvalidSignature)?; Ok(()) } @@ -291,7 +291,7 @@ pub fn hkdf_sha256(salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Resu .expand(info, key) .map_err(|e| { error!("Error with hkdf_sha256 {:?}", e); - Error::TLSStack + ErrorCode::TLSStack.into() }) } diff --git a/matter/src/crypto/mod.rs b/matter/src/crypto/mod.rs index 3b7b4c4e..47d49b72 100644 --- a/matter/src/crypto/mod.rs +++ b/matter/src/crypto/mod.rs @@ -15,7 +15,7 @@ * limitations under the License. */ use crate::{ - error::Error, + error::{Error, ErrorCode}, tlv::{FromTLV, TLVWriter, TagType, ToTLV}, }; @@ -80,12 +80,12 @@ impl<'a> FromTLV<'a> for KeyPair { t.confirm_array()?.enter(); if let Some(mut array) = t.enter() { - let pub_key = array.next().ok_or(Error::Invalid)?.slice()?; - let priv_key = array.next().ok_or(Error::Invalid)?.slice()?; + let pub_key = array.next().ok_or(ErrorCode::Invalid)?.slice()?; + let priv_key = array.next().ok_or(ErrorCode::Invalid)?.slice()?; KeyPair::new_from_components(pub_key, priv_key) } else { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } } @@ -108,7 +108,7 @@ impl ToTLV for KeyPair { #[cfg(test)] mod tests { - use crate::error::Error; + use crate::error::ErrorCode; use super::KeyPair; @@ -122,8 +122,9 @@ mod tests { fn test_verify_msg_fail() { let key = KeyPair::new_from_public(&test_vectors::PUB_KEY1).unwrap(); assert_eq!( - key.verify_msg(&test_vectors::MSG1_FAIL, &test_vectors::SIGNATURE1), - Err(Error::InvalidSignature) + key.verify_msg(&test_vectors::MSG1_FAIL, &test_vectors::SIGNATURE1) + .map_err(|e| e.code()), + Err(ErrorCode::InvalidSignature) ); } diff --git a/matter/src/data_model/cluster_template.rs b/matter/src/data_model/cluster_template.rs index c103812f..1e6adb8b 100644 --- a/matter/src/data_model/cluster_template.rs +++ b/matter/src/data_model/cluster_template.rs @@ -17,7 +17,7 @@ use crate::{ data_model::objects::{Cluster, Handler}, - error::Error, + error::{Error, ErrorCode}, utils::rand::Rand, }; @@ -51,7 +51,7 @@ impl TemplateCluster { if attr.is_system() { CLUSTER.read(attr.attr_id, writer) } else { - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } } else { Ok(()) diff --git a/matter/src/data_model/objects/cluster.rs b/matter/src/data_model/objects/cluster.rs index 3818f93b..f9f4c5cf 100644 --- a/matter/src/data_model/objects/cluster.rs +++ b/matter/src/data_model/objects/cluster.rs @@ -22,7 +22,7 @@ use crate::{ acl::{AccessReq, Accessor}, attribute_enum, data_model::objects::*, - error::Error, + error::{Error, ErrorCode}, interaction_model::{ core::IMStatusCode, messages::{ @@ -320,7 +320,7 @@ impl<'a> Cluster<'a> { GlobalElements::FeatureMap => writer.set(self.feature_map), other => { error!("This attribute is not yet handled {:?}", other); - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } } } diff --git a/matter/src/data_model/objects/encoder.rs b/matter/src/data_model/objects/encoder.rs index d068ce7b..e97eea0e 100644 --- a/matter/src/data_model/objects/encoder.rs +++ b/matter/src/data_model/objects/encoder.rs @@ -26,7 +26,7 @@ use crate::interaction_model::messages::ib::{ use crate::interaction_model::messages::GenericPath; use crate::tlv::UtfStr; use crate::{ - error::Error, + error::{Error, ErrorCode}, interaction_model::messages::ib::{AttrDataTag, AttrRespTag}, tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, }; @@ -135,8 +135,13 @@ impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { match handler.read(&attr, encoder) { Ok(()) => None, - Err(Error::NoSpace) => return Ok(Some(attr.path().to_gp())), - Err(error) => attr.status(error.into())?, + Err(e) => { + if e.code() == ErrorCode::NoSpace { + return Ok(Some(attr.path().to_gp())); + } else { + attr.status(e.into())? + } + } } } Err(status) => Some(status), @@ -181,8 +186,13 @@ impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { match handler.read(&attr, encoder).await { Ok(()) => None, - Err(Error::NoSpace) => return Ok(Some(attr.path().to_gp())), - Err(error) => attr.status(error.into())?, + Err(e) => { + if e.code() == ErrorCode::NoSpace { + return Ok(Some(attr.path().to_gp())); + } else { + attr.status(e.into())? + } + } } } Err(status) => Some(status), @@ -321,7 +331,7 @@ impl<'a> AttrData<'a> { pub fn with_dataver(self, dataver: u32) -> Result<&'a TLVElement<'a>, Error> { if let Some(req_dataver) = self.for_dataver { if req_dataver != dataver { - return Err(Error::DataVersionMismatch); + Err(ErrorCode::DataVersionMismatch)?; } } @@ -557,7 +567,8 @@ macro_rules! attribute_enum { type Error = $crate::error::Error; fn try_from(id: $crate::data_model::objects::AttrId) -> Result { - <$en>::from_repr(id).ok_or($crate::error::Error::AttributeNotFound) + <$en>::from_repr(id) + .ok_or_else(|| $crate::error::ErrorCode::AttributeNotFound.into()) } } }; @@ -571,7 +582,7 @@ macro_rules! command_enum { type Error = $crate::error::Error; fn try_from(id: $crate::data_model::objects::CmdId) -> Result { - <$en>::from_repr(id).ok_or($crate::error::Error::CommandNotFound) + <$en>::from_repr(id).ok_or_else(|| $crate::error::ErrorCode::CommandNotFound.into()) } } }; diff --git a/matter/src/data_model/objects/handler.rs b/matter/src/data_model/objects/handler.rs index 7758427f..a5e2b9c9 100644 --- a/matter/src/data_model/objects/handler.rs +++ b/matter/src/data_model/objects/handler.rs @@ -15,7 +15,11 @@ * limitations under the License. */ -use crate::{error::Error, interaction_model::core::Transaction, tlv::TLVElement}; +use crate::{ + error::{Error, ErrorCode}, + interaction_model::core::Transaction, + tlv::TLVElement, +}; use super::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails}; @@ -27,7 +31,7 @@ pub trait Handler { fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error>; fn write(&mut self, _attr: &AttrDetails, _data: AttrData) -> Result<(), Error> { - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } fn invoke( @@ -37,7 +41,7 @@ pub trait Handler { _data: &TLVElement, _encoder: CmdDataEncoder, ) -> Result<(), Error> { - Err(Error::CommandNotFound) + Err(ErrorCode::CommandNotFound.into()) } } @@ -88,7 +92,7 @@ impl EmptyHandler { impl Handler for EmptyHandler { fn read(&self, _attr: &AttrDetails, _encoder: AttrDataEncoder) -> Result<(), Error> { - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } } @@ -202,7 +206,7 @@ macro_rules! handler_chain_type { pub mod asynch { use crate::{ data_model::objects::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails}, - error::Error, + error::{Error, ErrorCode}, interaction_model::core::Transaction, tlv::TLVElement, }; @@ -221,7 +225,7 @@ pub mod asynch { _attr: &'a AttrDetails<'_>, _data: AttrData<'a>, ) -> Result<(), Error> { - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } async fn invoke<'a>( @@ -231,7 +235,7 @@ pub mod asynch { _data: &'a TLVElement<'_>, _encoder: CmdDataEncoder<'a, '_, '_>, ) -> Result<(), Error> { - Err(Error::CommandNotFound) + Err(ErrorCode::CommandNotFound.into()) } } @@ -305,7 +309,7 @@ pub mod asynch { _attr: &'a AttrDetails<'_>, _encoder: AttrDataEncoder<'a, '_, '_>, ) -> Result<(), Error> { - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } } diff --git a/matter/src/data_model/objects/privilege.rs b/matter/src/data_model/objects/privilege.rs index 6b4e3a5a..1032a450 100644 --- a/matter/src/data_model/objects/privilege.rs +++ b/matter/src/data_model/objects/privilege.rs @@ -16,7 +16,7 @@ */ use crate::{ - error::Error, + error::{Error, ErrorCode}, tlv::{FromTLV, TLVElement, ToTLV}, }; use log::error; @@ -47,12 +47,12 @@ impl FromTLV<'_> for Privilege { 1 => Ok(Privilege::VIEW), 2 => { error!("ProxyView privilege not yet supporteds"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } 3 => Ok(Privilege::OPERATE), 4 => Ok(Privilege::MANAGE), 5 => Ok(Privilege::ADMIN), - _ => Err(Error::Invalid), + _ => Err(ErrorCode::Invalid.into()), } } } diff --git a/matter/src/data_model/sdm/admin_commissioning.rs b/matter/src/data_model/sdm/admin_commissioning.rs index 5497426b..b63aa2e1 100644 --- a/matter/src/data_model/sdm/admin_commissioning.rs +++ b/matter/src/data_model/sdm/admin_commissioning.rs @@ -144,7 +144,7 @@ impl<'a> AdminCommCluster<'a> { ) -> Result<(), Error> { match cmd.cmd_id.try_into()? { Commands::OpenCommWindow => self.handle_command_opencomm_win(data)?, - _ => Err(Error::CommandNotFound)?, + _ => Err(ErrorCode::CommandNotFound)?, } self.data_ver.changed(); diff --git a/matter/src/data_model/sdm/failsafe.rs b/matter/src/data_model/sdm/failsafe.rs index 5008c9fa..301baf91 100644 --- a/matter/src/data_model/sdm/failsafe.rs +++ b/matter/src/data_model/sdm/failsafe.rs @@ -15,7 +15,10 @@ * limitations under the License. */ -use crate::{error::Error, transport::session::SessionMode}; +use crate::{ + error::{Error, ErrorCode}, + transport::session::SessionMode, +}; use log::error; #[derive(PartialEq)] @@ -62,7 +65,7 @@ impl FailSafe { State::Armed(c) => { if c.session_mode != session_mode { error!("Received Fail-Safe Arm with different session modes; current {:?}, incoming {:?}", c.session_mode, session_mode); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } // re-arm c.timeout = timeout; @@ -75,22 +78,22 @@ impl FailSafe { match &mut self.state { State::Idle => { error!("Received Fail-Safe Disarm without it being armed"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } State::Armed(c) => { match c.noc_state { - NocState::NocNotRecvd => return Err(Error::Invalid), + NocState::NocNotRecvd => Err(ErrorCode::Invalid)?, NocState::AddNocRecvd(idx) | NocState::UpdateNocRecvd(idx) => { if let SessionMode::Case(c) = session_mode { if c.fab_idx != idx { error!( "Received disarm in separate session from previous Add/Update NOC" ); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } } else { error!("Received disarm in a non-CASE session"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } } } @@ -106,13 +109,13 @@ impl FailSafe { pub fn record_add_noc(&mut self, fabric_index: u8) -> Result<(), Error> { match &mut self.state { - State::Idle => Err(Error::Invalid), + State::Idle => Err(ErrorCode::Invalid.into()), State::Armed(c) => { if c.noc_state == NocState::NocNotRecvd { c.noc_state = NocState::AddNocRecvd(fabric_index); Ok(()) } else { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } } diff --git a/matter/src/data_model/sdm/general_commissioning.rs b/matter/src/data_model/sdm/general_commissioning.rs index d4d43297..f2487ef2 100644 --- a/matter/src/data_model/sdm/general_commissioning.rs +++ b/matter/src/data_model/sdm/general_commissioning.rs @@ -235,9 +235,9 @@ impl<'a> GenCommCluster<'a> { cmd_enter!("Set Regulatory Config"); let country_code = data .find_tag(1) - .map_err(|_| Error::InvalidCommand)? + .map_err(|_| ErrorCode::InvalidCommand)? .slice() - .map_err(|_| Error::InvalidCommand)?; + .map_err(|_| ErrorCode::InvalidCommand)?; info!("Received country code: {:?}", country_code); let cmd_data = CommonResponse { diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index 634ba85f..6182c0c0 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -277,7 +277,7 @@ impl<'a> NocCluster<'a> { } _ => { error!("Attribute not supported: this shouldn't happen"); - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } } } @@ -563,7 +563,7 @@ impl<'a> NocCluster<'a> { info!("Received CSR Nonce:{:?}", req.str); if !self.failsafe.borrow().is_armed() { - return Err(Error::UnsupportedAccess); + Err(ErrorCode::UnsupportedAccess)?; } let noc_keypair = KeyPair::new()?; @@ -602,7 +602,7 @@ impl<'a> NocCluster<'a> { ) -> Result<(), Error> { cmd_enter!("AddTrustedRootCert"); if !self.failsafe.borrow().is_armed() { - return Err(Error::UnsupportedAccess); + Err(ErrorCode::UnsupportedAccess)?; } // This may happen on CASE or PASE. For PASE, the existence of NOC Data is necessary @@ -612,13 +612,13 @@ impl<'a> NocCluster<'a> { let noc_data = transaction .session_mut() .get_noc_data::() - .ok_or(Error::NoSession)?; + .ok_or(ErrorCode::NoSession)?; let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; info!("Received Trusted Cert:{:x?}", req.str); noc_data.root_ca = - heapless::Vec::from_slice(req.str.0).map_err(|_| Error::BufferTooSmall)?; + heapless::Vec::from_slice(req.str.0).map_err(|_| ErrorCode::BufferTooSmall)?; // TODO } _ => (), @@ -720,6 +720,6 @@ fn get_certchainrequest_params(data: &TLVElement) -> Result { match cert_type { CERT_TYPE_DAC => Ok(dev_att::DataType::DAC), CERT_TYPE_PAI => Ok(dev_att::DataType::PAI), - _ => Err(Error::Invalid), + _ => Err(ErrorCode::Invalid.into()), } } diff --git a/matter/src/data_model/sdm/nw_commissioning.rs b/matter/src/data_model/sdm/nw_commissioning.rs index 47ffe6ed..5abf8091 100644 --- a/matter/src/data_model/sdm/nw_commissioning.rs +++ b/matter/src/data_model/sdm/nw_commissioning.rs @@ -20,7 +20,7 @@ use crate::{ AttrDataEncoder, AttrDetails, ChangeNotifier, Cluster, Dataver, Handler, NonBlockingHandler, ATTRIBUTE_LIST, FEATURE_MAP, }, - error::Error, + error::{Error, ErrorCode}, utils::rand::Rand, }; @@ -57,7 +57,7 @@ impl Handler for NwCommCluster { if attr.is_system() { CLUSTER.read(attr.attr_id, writer) } else { - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } } else { Ok(()) diff --git a/matter/src/data_model/system_model/access_control.rs b/matter/src/data_model/system_model/access_control.rs index ffba5e67..c57c0df2 100644 --- a/matter/src/data_model/system_model/access_control.rs +++ b/matter/src/data_model/system_model/access_control.rs @@ -141,7 +141,7 @@ impl<'a> AccessControlCluster<'a> { } _ => { error!("Attribute not yet supported: this shouldn't happen"); - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } } } @@ -229,7 +229,7 @@ mod tests { // Test, ACL has fabric index 2, but the accessing fabric is 1 // the fabric index in the TLV should be ignored and the ACL should be created with entry 1 let result = acl.write_acl_attr(&ListOperation::AddItem, &data, 1); - assert_eq!(result, Ok(())); + assert!(result.is_ok()); let verifier = AclEntry::new(1, Privilege::VIEW, AuthMode::Case); acl_mgr @@ -268,7 +268,7 @@ mod tests { let result = acl.write_acl_attr(&ListOperation::EditItem(1), &data, 2); // Fabric 2's index 1, is actually our index 2, update the verifier verifier[2] = new; - assert_eq!(result, Ok(())); + assert!(result.is_ok()); // Also validate in the acl_mgr that the entries are in the right order let mut index = 0; @@ -301,7 +301,7 @@ mod tests { // Test , Delete Fabric 1's index 0 let result = acl.write_acl_attr(&ListOperation::DeleteItem(0), &data, 1); - assert_eq!(result, Ok(())); + assert!(result.is_ok()); let verifier = [input[0], input[2]]; // Also validate in the acl_mgr that the entries are in the right order diff --git a/matter/src/error.rs b/matter/src/error.rs index d2053d16..507ce4bf 100644 --- a/matter/src/error.rs +++ b/matter/src/error.rs @@ -17,8 +17,8 @@ use core::{array::TryFromSliceError, fmt, str::Utf8Error}; -#[derive(Debug, PartialEq, Clone, Copy)] -pub enum Error { +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum ErrorCode { AttributeNotFound, AttributeIsCustom, BufferTooSmall, @@ -73,7 +73,36 @@ pub enum Error { Utf8Fail, } +impl From for Error { + fn from(code: ErrorCode) -> Self { + Self::new(code) + } +} + +pub struct Error { + code: ErrorCode, + #[cfg(all(feature = "std", feature = "backtrace"))] + backtrace: std::backtrace::Backtrace, +} + impl Error { + pub fn new(code: ErrorCode) -> Self { + Self { + code, + #[cfg(all(feature = "std", feature = "backtrace"))] + backtrace: std::backtrace::Backtrace::capture(), + } + } + + pub const fn code(&self) -> ErrorCode { + self.code + } + + #[cfg(all(feature = "std", feature = "backtrace"))] + pub const fn backtrace(&self) -> &std::backtrace::Backtrace { + &self.backtrace + } + pub fn remap(self, matcher: F, to: Self) -> Self where F: FnOnce(&Self) -> bool, @@ -86,19 +115,22 @@ impl Error { } pub fn map_invalid(self, to: Self) -> Self { - self.remap(|e| matches!(e, Self::Invalid | Self::InvalidData), to) + self.remap( + |e| matches!(e.code(), ErrorCode::Invalid | ErrorCode::InvalidData), + to, + ) } pub fn map_invalid_command(self) -> Self { - self.map_invalid(Error::InvalidCommand) + self.map_invalid(Error::new(ErrorCode::InvalidCommand)) } pub fn map_invalid_action(self) -> Self { - self.map_invalid(Error::InvalidAction) + self.map_invalid(Error::new(ErrorCode::InvalidAction)) } pub fn map_invalid_data_type(self) -> Self { - self.map_invalid(Error::InvalidDataType) + self.map_invalid(Error::new(ErrorCode::InvalidDataType)) } } @@ -106,14 +138,14 @@ impl Error { impl From for Error { fn from(_e: std::io::Error) -> Self { // Keep things simple for now - Self::StdIoError + Self::new(ErrorCode::StdIoError) } } #[cfg(feature = "std")] impl From> for Error { fn from(_e: std::sync::PoisonError) -> Self { - Self::RwLock + Self::new(ErrorCode::RwLock) } } @@ -121,7 +153,7 @@ impl From> for Error { impl From for Error { fn from(e: openssl::error::ErrorStack) -> Self { ::log::error!("Error in TLS: {}", e); - Self::TLSStack + Self::new(ErrorCode::TLSStack) } } @@ -129,39 +161,57 @@ impl From for Error { impl From for Error { fn from(e: mbedtls::Error) -> Self { ::log::error!("Error in TLS: {}", e); - Self::TLSStack + Self::new(ErrorCode::TLSStack) } } #[cfg(feature = "crypto_rustcrypto")] impl From for Error { fn from(_e: ccm::aead::Error) -> Self { - Self::Crypto + Self::new(ErrorCode::Crypto) } } #[cfg(feature = "std")] impl From for Error { fn from(_e: std::time::SystemTimeError) -> Self { - Self::SysTimeFail + Error::new(ErrorCode::SysTimeFail) } } impl From for Error { fn from(_e: TryFromSliceError) -> Self { - Self::Invalid + Self::new(ErrorCode::Invalid) } } impl From for Error { fn from(_e: Utf8Error) -> Self { - Self::Utf8Fail + Self::new(ErrorCode::Utf8Fail) + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[cfg(not(all(feature = "std", feature = "backtrace")))] + { + write!(f, "Error::{}", self)?; + } + + #[cfg(all(feature = "std", feature = "backtrace"))] + { + write!(f, "Error::{} {{\n", self)?; + write!(f, "{}", self.backtrace())?; + write!(f, "}}\n")?; + } + + Ok(()) } } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{:?}", self) + write!(f, "{:?}", self.code()) } } diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index 6f3ff0ee..5658d4c1 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -24,10 +24,10 @@ use log::info; use crate::{ cert::{Cert, MAX_CERT_TLV_LEN}, crypto::{self, hkdf_sha256, HmacSha256, KeyPair}, - error::Error, + error::{Error, ErrorCode}, group_keys::KeySet, mdns::{MdnsMgr, ServiceMode}, - tlv::{FromTLV, OctetStr, TLVElement, TLVList, TLVWriter, TagType, ToTLV, UtfStr}, + tlv::{FromTLV, OctetStr, TLVList, TLVWriter, TagType, ToTLV, UtfStr}, utils::writebuf::WriteBuf, }; @@ -123,7 +123,7 @@ impl Fabric { 0x69, 0x63, ]; hkdf_sha256(&fabric_id_be, root_pubkey, &COMPRESSED_FABRIC_ID_INFO, out) - .map_err(|_| Error::NoSpace) + .map_err(|_| Error::from(ErrorCode::NoSpace)) } pub fn match_dest_id(&self, random: &[u8], target: &[u8]) -> Result<(), Error> { @@ -144,7 +144,7 @@ impl Fabric { if id.as_slice() == target { Ok(()) } else { - Err(Error::NotFound) + Err(ErrorCode::NotFound.into()) } } @@ -208,7 +208,7 @@ impl FabricMgr { } } - let root = TLVList::new(data).iter().next().ok_or(Error::Invalid)?; + let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?; self.fabrics = FabricEntries::from_tlv(&root)?; @@ -227,6 +227,7 @@ impl FabricMgr { if self.changed { let mut wb = WriteBuf::new(buf); let mut tw = TLVWriter::new(&mut wb); + self.fabrics.to_tlv(&mut tw, TagType::Anonymous)?; self.changed = false; @@ -252,7 +253,7 @@ impl FabricMgr { } } - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } pub fn remove(&mut self, fab_idx: u8, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> { @@ -262,10 +263,10 @@ impl FabricMgr { self.changed = true; Ok(()) } else { - Err(Error::NotFound) + Err(ErrorCode::NotFound.into()) } } else { - Err(Error::NotFound) + Err(ErrorCode::NotFound.into()) } } @@ -277,7 +278,7 @@ impl FabricMgr { } } } - Err(Error::NotFound) + Err(ErrorCode::NotFound.into()) } pub fn get_fabric(&self, idx: usize) -> Result, Error> { @@ -317,7 +318,7 @@ impl FabricMgr { .filter_map(|f| f.as_ref()) .any(|f| f.label == label) { - return Err(Error::Invalid); + return Err(ErrorCode::Invalid.into()); } } diff --git a/matter/src/group_keys.rs b/matter/src/group_keys.rs index d4e97659..7b584e14 100644 --- a/matter/src/group_keys.rs +++ b/matter/src/group_keys.rs @@ -17,8 +17,8 @@ use crate::{ crypto::{self, SYMM_KEY_LEN_BYTES}, - error::Error, - tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, + error::{Error, ErrorCode}, + tlv::{FromTLV, TLVWriter, TagType, ToTLV}, }; type KeySetKey = [u8; SYMM_KEY_LEN_BYTES]; @@ -42,7 +42,8 @@ impl KeySet { 0x47, 0x72, 0x6f, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x20, 0x76, 0x31, 0x2e, 0x30, ]; - crypto::hkdf_sha256(compressed_id, ipk, &GRP_KEY_INFO, opkey).map_err(|_| Error::NoSpace) + crypto::hkdf_sha256(compressed_id, ipk, &GRP_KEY_INFO, opkey) + .map_err(|_| ErrorCode::NoSpace.into()) } pub fn op_key(&self) -> &[u8] { diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index 2ce9c82b..9e29bac4 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -78,27 +78,33 @@ pub enum IMStatusCode { FailSafeRequired = 0xca, } -impl From for IMStatusCode { - fn from(e: Error) -> Self { +impl From for IMStatusCode { + fn from(e: ErrorCode) -> Self { match e { - Error::EndpointNotFound => IMStatusCode::UnsupportedEndpoint, - Error::ClusterNotFound => IMStatusCode::UnsupportedCluster, - Error::AttributeNotFound => IMStatusCode::UnsupportedAttribute, - Error::CommandNotFound => IMStatusCode::UnsupportedCommand, - Error::InvalidAction => IMStatusCode::InvalidAction, - Error::InvalidCommand => IMStatusCode::InvalidCommand, - Error::UnsupportedAccess => IMStatusCode::UnsupportedAccess, - Error::Busy => IMStatusCode::Busy, - Error::DataVersionMismatch => IMStatusCode::DataVersionMismatch, - Error::ResourceExhausted => IMStatusCode::ResourceExhausted, + ErrorCode::EndpointNotFound => IMStatusCode::UnsupportedEndpoint, + ErrorCode::ClusterNotFound => IMStatusCode::UnsupportedCluster, + ErrorCode::AttributeNotFound => IMStatusCode::UnsupportedAttribute, + ErrorCode::CommandNotFound => IMStatusCode::UnsupportedCommand, + ErrorCode::InvalidAction => IMStatusCode::InvalidAction, + ErrorCode::InvalidCommand => IMStatusCode::InvalidCommand, + ErrorCode::UnsupportedAccess => IMStatusCode::UnsupportedAccess, + ErrorCode::Busy => IMStatusCode::Busy, + ErrorCode::DataVersionMismatch => IMStatusCode::DataVersionMismatch, + ErrorCode::ResourceExhausted => IMStatusCode::ResourceExhausted, _ => IMStatusCode::Failure, } } } +impl From for IMStatusCode { + fn from(value: Error) -> Self { + Self::from(value.code()) + } +} + impl FromTLV<'_> for IMStatusCode { fn from_tlv(t: &TLVElement) -> Result { - num::FromPrimitive::from_u16(t.u16()?).ok_or(Error::Invalid) + num::FromPrimitive::from_u16(t.u16()?).ok_or_else(|| ErrorCode::Invalid.into()) } } @@ -223,7 +229,7 @@ pub enum Interaction<'a> { impl<'a> Interaction<'a> { fn new(rx: &'a Packet, transaction: &mut Transaction) -> Result, Error> { let opcode: OpCode = - num::FromPrimitive::from_u8(rx.get_proto_opcode()).ok_or(Error::Invalid)?; + num::FromPrimitive::from_u8(rx.get_proto_opcode()).ok_or(ErrorCode::Invalid)?; let rx_data = rx.as_slice(); @@ -264,7 +270,7 @@ impl<'a> Interaction<'a> { )?))), _ => { error!("Opcode not handled: {:?}", opcode); - Err(Error::InvalidOpcode) + Err(ErrorCode::InvalidOpcode.into()) } } } diff --git a/matter/src/interaction_model/messages.rs b/matter/src/interaction_model/messages.rs index 19e29f18..edf65db9 100644 --- a/matter/src/interaction_model/messages.rs +++ b/matter/src/interaction_model/messages.rs @@ -17,8 +17,8 @@ use crate::{ data_model::objects::{ClusterId, EndptId}, - error::Error, - tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, + error::{Error, ErrorCode}, + tlv::{FromTLV, TLVWriter, TagType, ToTLV}, }; // A generic path with endpoint, clusters, and a leaf @@ -48,7 +48,7 @@ impl GenericPath { cluster: Some(c), leaf: Some(l), } => Ok((e, c, l)), - _ => Err(Error::Invalid), + _ => Err(ErrorCode::Invalid.into()), } } /// Returns true, if the path is wildcard @@ -69,7 +69,7 @@ pub mod msg { use crate::{ error::Error, interaction_model::core::IMStatusCode, - tlv::{FromTLV, TLVArray, TLVElement, TLVWriter, TagType, ToTLV}, + tlv::{FromTLV, TLVArray, TLVWriter, TagType, ToTLV}, }; use super::ib::{ @@ -259,7 +259,7 @@ pub mod ib { use crate::{ data_model::objects::{AttrDetails, AttrId, ClusterId, CmdId, EncodeValue, EndptId}, - error::Error, + error::{Error, ErrorCode}, interaction_model::core::IMStatusCode, tlv::{FromTLV, Nullable, TLVElement, TLVWriter, TagType, ToTLV}, }; @@ -447,7 +447,7 @@ pub mod ib { f(ListOperation::DeleteList, data)?; // Now the data must be a list, that should be added item by item - let container = data.enter().ok_or(Error::Invalid)?; + let container = data.enter().ok_or(ErrorCode::Invalid)?; for d in container { f(ListOperation::AddItem, &d)?; } @@ -544,7 +544,7 @@ pub mod ib { if c.path.leaf.is_none() { error!("Wildcard command parameter not supported"); - Err(Error::CommandNotFound) + Err(ErrorCode::CommandNotFound.into()) } else { Ok(c) } diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index 1b296187..eb19a9f1 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -217,7 +217,7 @@ pub mod astro { use std::collections::HashMap; use super::Mdns; - use crate::error::Error; + use crate::error::{Error, ErrorCode}; use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService}; use log::info; @@ -269,7 +269,7 @@ pub mod astro { builder = builder.with_key_value(kvs.0.to_string(), kvs.1.to_string()); } - let svc = builder.register().map_err(|_| Error::MdnsError)?; + let svc = builder.register().map_err(|_| ErrorCode::MdnsError)?; self.services.insert( ServiceId { @@ -348,7 +348,7 @@ pub mod astro { // use std::collections::HashMap; // use super::Mdns; -// use crate::error::Error; +// use crate::error::{Error, ErrorCode}; // use log::info; // use zeroconf::prelude::*; // use zeroconf::{MdnsService, ServiceType, TxtRecord}; @@ -402,7 +402,7 @@ pub mod astro { // svc.set_txt_record(txt); -// //let event_loop = svc.register().map_err(|_| Error::MdnsError)?; +// //let event_loop = svc.register().map_err(|_| ErrorCode::MdnsError)?; // self.services.insert( // ServiceId { @@ -604,7 +604,7 @@ pub mod libmdns { // pub mod simplemdns { // use std::net::Ipv4Addr; -// use crate::error::Error; +// use crate::error::{Error, ErrorCode}; // use super::Mdns; // use log::info; // use simple_dns::{ diff --git a/matter/src/pairing/qr.rs b/matter/src/pairing/qr.rs index e99d909c..aa264811 100644 --- a/matter/src/pairing/qr.rs +++ b/matter/src/pairing/qr.rs @@ -16,6 +16,7 @@ */ use crate::{ + error::ErrorCode, tlv::{TLVWriter, TagType}, utils::writebuf::WriteBuf, }; @@ -134,7 +135,7 @@ impl<'data> QrSetupPayload<'data> { if is_vendor_tag(tag) { self.add_optional_data(tag, data) } else { - Err(Error::InvalidArgument) + Err(ErrorCode::InvalidArgument.into()) } } @@ -150,7 +151,7 @@ impl<'data> QrSetupPayload<'data> { if is_common_tag(tag) { self.add_optional_data(tag, data) } else { - Err(Error::InvalidArgument) + Err(ErrorCode::InvalidArgument.into()) } } @@ -163,7 +164,7 @@ impl<'data> QrSetupPayload<'data> { } else { self.optional_data.push(item) } - .map_err(|_| Error::NoSpace) + .map_err(|_| ErrorCode::NoSpace.into()) } pub fn get_all_optional_data(&self) -> &[OptionalQRCodeInfo] { @@ -267,7 +268,7 @@ pub(super) fn payload_base38_representation( payload_base38_representation_with_tlv(payload, bits_buf, tlv_buf) } else { - Err(Error::InvalidArgument) + Err(ErrorCode::InvalidArgument.into()) } } @@ -299,7 +300,7 @@ pub fn estimate_buffer_size(payload: &QrSetupPayload) -> Result { estimate = estimate_struct_overhead(estimate); if estimate > u32::MAX as usize { - return Err(Error::NoMemory); + Err(ErrorCode::NoMemory)?; } Ok(estimate) @@ -352,11 +353,11 @@ fn populate_bits( total_payload_data_size_in_bits: usize, ) -> Result<(), Error> { if *offset + number_of_bits > total_payload_data_size_in_bits { - return Err(Error::InvalidArgument); + Err(ErrorCode::InvalidArgument)?; } if input >= 1u64 << number_of_bits { - return Err(Error::InvalidArgument); + Err(ErrorCode::InvalidArgument)?; } let mut index = *offset; @@ -390,7 +391,7 @@ fn payload_base38_representation_with_tlv( let mut base38_encoded: heapless::String = "MT:".into(); for c in base38::encode(bits) { - base38_encoded.push(c).map_err(|_| Error::NoSpace)?; + base38_encoded.push(c).map_err(|_| ErrorCode::NoSpace)?; } Ok(base38_encoded) @@ -431,7 +432,7 @@ fn generate_bit_set<'a>( TOTAL_PAYLOAD_DATA_SIZE_IN_BITS + tlv_data.map(|tlv_data| tlv_data.len() * 8).unwrap_or(0); if bits_buf.len() * 8 < total_payload_size_in_bits { - return Err(Error::BufferTooSmall); + Err(ErrorCode::BufferTooSmall)?; }; let passwd = passwd_from_comm_data(payload.comm_data); diff --git a/matter/src/persist.rs b/matter/src/persist.rs index 53e413eb..d9a27330 100644 --- a/matter/src/persist.rs +++ b/matter/src/persist.rs @@ -25,7 +25,7 @@ mod file_psm { use log::info; - use crate::error::Error; + use crate::error::{Error, ErrorCode}; pub struct FilePsm { dir: PathBuf, @@ -47,7 +47,7 @@ mod file_psm { loop { if offset == buf.len() { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } let len = file.read(&mut buf[offset..])?; diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index e681ec92..18011c99 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -22,11 +22,11 @@ use log::{error, trace}; use crate::{ cert::Cert, crypto::{self, KeyPair, Sha256}, - error::Error, + error::{Error, ErrorCode}, fabric::{Fabric, FabricMgr}, secure_channel::common::SCStatusCodes, secure_channel::common::{self, OpCode}, - tlv::{get_root_node_struct, FromTLV, OctetStr, TLVElement, TLVWriter, TagType}, + tlv::{get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType}, transport::{ network::Address, proto_ctx::ProtoCtx, @@ -90,9 +90,9 @@ impl<'a> Case<'a> { .exch_ctx .exch .take_case_session::() - .ok_or(Error::InvalidState)?; + .ok_or(ErrorCode::InvalidState)?; if case_session.state != State::Sigma1Rx { - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } case_session.state = State::Sigma3Rx; @@ -117,7 +117,7 @@ impl<'a> Case<'a> { let mut decrypted: [u8; 800] = [0; 800]; if encrypted.len() > decrypted.len() { error!("Data too large"); - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } let decrypted = &mut decrypted[..encrypted.len()]; decrypted.copy_from_slice(encrypted); @@ -204,7 +204,7 @@ impl<'a> Case<'a> { case_session.local_fabric_idx = local_fabric_idx?; if r.peer_pub_key.0.len() != crypto::EC_POINT_LEN_BYTES { error!("Invalid public key length"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } case_session.peer_pub_key.copy_from_slice(r.peer_pub_key.0); trace!( @@ -220,7 +220,7 @@ impl<'a> Case<'a> { let len = key_pair.derive_secret(r.peer_pub_key.0, &mut case_session.shared_secret)?; if len != 32 { error!("Derived secret length incorrect"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } // println!("Derived secret: {:x?} len: {}", secret, len); @@ -348,14 +348,14 @@ impl<'a> Case<'a> { let mut verifier = noc.verify_chain_start(utc_calendar); if fabric.get_fabric_id() != noc.get_fabric_id()? { - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } if let Some(icac) = icac { // If ICAC is present handle it if let Ok(fid) = icac.get_fabric_id() { if fid != fabric.get_fabric_id() { - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } } verifier = verifier.add_cert(icac)?; @@ -377,7 +377,7 @@ impl<'a> Case<'a> { 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, 0x65, 0x79, 0x73, ]; if key.len() < 48 { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } let mut salt = heapless::Vec::::new(); salt.extend_from_slice(ipk).unwrap(); @@ -388,7 +388,7 @@ impl<'a> Case<'a> { // println!("Session Key: salt: {:x?}, len: {}", salt, salt.len()); crypto::hkdf_sha256(salt.as_slice(), shared_secret, &SEKEYS_INFO, key) - .map_err(|_x| Error::NoSpace)?; + .map_err(|_x| ErrorCode::NoSpace)?; // println!("Session Key: key: {:x?}", key); Ok(()) @@ -425,7 +425,7 @@ impl<'a> Case<'a> { ) -> Result<(), Error> { const S3K_INFO: [u8; 6] = [0x53, 0x69, 0x67, 0x6d, 0x61, 0x33]; if key.len() < 16 { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } let mut salt = heapless::Vec::::new(); salt.extend_from_slice(ipk).unwrap(); @@ -438,7 +438,7 @@ impl<'a> Case<'a> { // println!("Sigma3Key: salt: {:x?}, len: {}", salt, salt.len()); crypto::hkdf_sha256(salt.as_slice(), shared_secret, &S3K_INFO, key) - .map_err(|_x| Error::NoSpace)?; + .map_err(|_x| ErrorCode::NoSpace)?; // println!("Sigma3Key: key: {:x?}", key); Ok(()) @@ -452,7 +452,7 @@ impl<'a> Case<'a> { ) -> Result<(), Error> { const S2K_INFO: [u8; 6] = [0x53, 0x69, 0x67, 0x6d, 0x61, 0x32]; if key.len() < 16 { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } let mut salt = heapless::Vec::::new(); salt.extend_from_slice(ipk).unwrap(); @@ -467,7 +467,7 @@ impl<'a> Case<'a> { // println!("Sigma2Key: salt: {:x?}, len: {}", salt, salt.len()); crypto::hkdf_sha256(salt.as_slice(), &case_session.shared_secret, &S2K_INFO, key) - .map_err(|_x| Error::NoSpace)?; + .map_err(|_x| ErrorCode::NoSpace)?; // println!("Sigma2Key: key: {:x?}", key); Ok(()) diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index fd13206e..653ad741 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -57,7 +57,7 @@ impl<'a> SecureChannel<'a> { pub fn handle(&mut self, ctx: &mut ProtoCtx) -> Result<(bool, Option), Error> { let proto_opcode: OpCode = - num::FromPrimitive::from_u8(ctx.rx.get_proto_opcode()).ok_or(Error::Invalid)?; + num::FromPrimitive::from_u8(ctx.rx.get_proto_opcode()).ok_or(ErrorCode::Invalid)?; ctx.tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); info!("Received Opcode: {:?}", proto_opcode); info!("Received Data:"); @@ -82,7 +82,7 @@ impl<'a> SecureChannel<'a> { OpCode::CASESigma3 => self.case.casesigma3_handler(ctx), _ => { error!("OpCode Not Handled: {:?}", proto_opcode); - Err(Error::InvalidOpcode) + Err(ErrorCode::InvalidOpcode.into()) } }?; diff --git a/matter/src/secure_channel/crypto_dummy.rs b/matter/src/secure_channel/crypto_dummy.rs index 11ec8523..3933e797 100644 --- a/matter/src/secure_channel/crypto_dummy.rs +++ b/matter/src/secure_channel/crypto_dummy.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use crate::error::Error; +use crate::error::{Error, ErrorCode}; #[allow(non_snake_case)] @@ -29,35 +29,35 @@ impl CryptoSpake2 { // Computes w0 from w0s respectively pub fn set_w0_from_w0s(&mut self, _w0s: &[u8]) -> Result<(), Error> { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn set_w1_from_w1s(&mut self, _w1s: &[u8]) -> Result<(), Error> { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn set_w0(&mut self, _w0: &[u8]) -> Result<(), Error> { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn set_w1(&mut self, _w1: &[u8]) -> Result<(), Error> { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } #[allow(non_snake_case)] pub fn set_L(&mut self, _l: &[u8]) -> Result<(), Error> { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } #[allow(non_snake_case)] #[allow(dead_code)] pub fn set_L_from_w1s(&mut self, _w1s: &[u8]) -> Result<(), Error> { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } #[allow(non_snake_case)] pub fn get_pB(&mut self, _pB: &mut [u8]) -> Result<(), Error> { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } #[allow(non_snake_case)] @@ -68,6 +68,6 @@ impl CryptoSpake2 { _pB: &[u8], _out: &mut [u8], ) -> Result<(), Error> { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } diff --git a/matter/src/secure_channel/crypto_mbedtls.rs b/matter/src/secure_channel/crypto_mbedtls.rs index 27c9fc61..de7ea487 100644 --- a/matter/src/secure_channel/crypto_mbedtls.rs +++ b/matter/src/secure_channel/crypto_mbedtls.rs @@ -18,7 +18,7 @@ use alloc::sync::Arc; use core::ops::{Mul, Sub}; -use crate::error::Error; +use crate::error::{Error, ErrorCode}; use byteorder::{ByteOrder, LittleEndian}; use log::error; @@ -150,7 +150,7 @@ impl CryptoSpake2 { let pB_internal = pB_internal.as_slice(); if pB_internal.len() != pB.len() { error!("pB length mismatch"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } pB.copy_from_slice(pB_internal); Ok(()) diff --git a/matter/src/secure_channel/crypto_openssl.rs b/matter/src/secure_channel/crypto_openssl.rs index 631cb6b9..de60fff2 100644 --- a/matter/src/secure_channel/crypto_openssl.rs +++ b/matter/src/secure_channel/crypto_openssl.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use crate::error::Error; +use crate::error::{Error, ErrorCode}; use byteorder::{ByteOrder, LittleEndian}; use log::error; @@ -158,7 +158,7 @@ impl CryptoSpake2 { let pB_internal = pB_internal.as_slice(); if pB_internal.len() != pB.len() { error!("pB length mismatch"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } pB.copy_from_slice(pB_internal); Ok(()) diff --git a/matter/src/secure_channel/pake.rs b/matter/src/secure_channel/pake.rs index 1901686c..84c5ba09 100644 --- a/matter/src/secure_channel/pake.rs +++ b/matter/src/secure_channel/pake.rs @@ -23,10 +23,10 @@ use super::{ }; use crate::{ crypto, - error::Error, + error::{Error, ErrorCode}, mdns::{MdnsMgr, ServiceMode}, secure_channel::common::OpCode, - tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV}, + tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV}, transport::{ exchange::ExchangeCtx, network::Address, @@ -176,7 +176,7 @@ impl PakeState { if let PakeState::InProgress(s) = new { Ok(s) } else { - Err(Error::InvalidSignature) + Err(ErrorCode::InvalidSignature.into()) } } @@ -187,7 +187,7 @@ impl PakeState { fn take_sess_data(&mut self, exch_ctx: &ExchangeCtx) -> Result { let sd = self.take()?; if sd.exch_id != exch_ctx.exch.get_id() || sd.peer_addr != exch_ctx.sess.get_peer_addr() { - Err(Error::InvalidState) + Err(ErrorCode::InvalidState.into()) } else { Ok(sd) } @@ -240,10 +240,10 @@ impl Pake { let clone_data = if status_code == SCStatusCodes::SessionEstablishmentSuccess { // Get the keys - let ke = ke.ok_or(Error::Invalid)?; + let ke = ke.ok_or(ErrorCode::Invalid)?; let mut session_keys: [u8; 48] = [0; 48]; crypto::hkdf_sha256(&[], ke, &SPAKE2_SESSION_KEYS_INFO, &mut session_keys) - .map_err(|_x| Error::NoSpace)?; + .map_err(|_x| ErrorCode::NoSpace)?; // Create a session let data = sd.spake2p.get_app_data(); @@ -314,7 +314,7 @@ impl Pake { let a = PBKDFParamReq::from_tlv(&root)?; if a.passcode_id != 0 { error!("Can't yet handle passcode_id != 0"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } let mut our_random: [u8; 32] = [0; 32]; diff --git a/matter/src/secure_channel/spake2p.rs b/matter/src/secure_channel/spake2p.rs index 8a4b794c..9be2d4de 100644 --- a/matter/src/secure_channel/spake2p.rs +++ b/matter/src/secure_channel/spake2p.rs @@ -25,7 +25,7 @@ use subtle::ConstantTimeEq; use crate::{ crypto::{pbkdf2_hmac, Sha256}, - error::Error, + error::{Error, ErrorCode}, }; use super::{common::SCStatusCodes, crypto::CryptoSpake2}; @@ -198,7 +198,7 @@ impl Spake2P { #[allow(non_snake_case)] pub fn handle_pA(&mut self, pA: &[u8], pB: &mut [u8], cB: &mut [u8]) -> Result<(), Error> { if self.mode != Spake2Mode::Verifier(Spake2VerifierState::Init) { - return Err(Error::InvalidState); + Err(ErrorCode::InvalidState)?; } if let Some(crypto_spake2) = &mut self.crypto_spake2 { @@ -251,13 +251,13 @@ impl Spake2P { if ke_internal.len() == Ke.len() { Ke.copy_from_slice(ke_internal); } else { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } // Step 2: KcA || KcB = KDF(nil, Ka, "ConfirmationKeys") let mut KcAKcB: [u8; 32] = [0; 32]; crypto::hkdf_sha256(&[], Ka, &SPAKE2P_KEY_CONFIRM_INFO, &mut KcAKcB) - .map_err(|_x| Error::NoSpace)?; + .map_err(|_x| ErrorCode::NoSpace)?; let KcA = &KcAKcB[0..(KcAKcB.len() / 2)]; let KcB = &KcAKcB[(KcAKcB.len() / 2)..]; diff --git a/matter/src/secure_channel/status_report.rs b/matter/src/secure_channel/status_report.rs index 477bcfae..2f6aed13 100644 --- a/matter/src/secure_channel/status_report.rs +++ b/matter/src/secure_channel/status_report.rs @@ -46,6 +46,7 @@ pub fn create_status_report( proto_code: u16, proto_data: Option<&[u8]>, ) -> Result<(), Error> { + proto_tx.reset(); proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); proto_tx.set_proto_opcode(OpCode::StatusReport as u8); let wb = proto_tx.get_writebuf()?; diff --git a/matter/src/tlv/parser.rs b/matter/src/tlv/parser.rs index f8b9716c..b740f5d0 100644 --- a/matter/src/tlv/parser.rs +++ b/matter/src/tlv/parser.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use crate::error::Error; +use crate::error::{Error, ErrorCode}; use byteorder::{ByteOrder, LittleEndian}; use core::fmt; @@ -284,7 +284,7 @@ fn read_length_value<'a>( // We'll consume the current offset (len) + the entire string if length + size_of_length_field > t.left { // Return Error - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } else { Ok(( // return the additional size only @@ -390,14 +390,14 @@ impl<'a> TLVElement<'a> { pub fn i8(&self) -> Result { match self.element_type { ElementType::S8(a) => Ok(a), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } pub fn u8(&self) -> Result { match self.element_type { ElementType::U8(a) => Ok(a), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } @@ -405,7 +405,7 @@ impl<'a> TLVElement<'a> { match self.element_type { ElementType::U8(a) => Ok(a.into()), ElementType::U16(a) => Ok(a), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } @@ -414,7 +414,7 @@ impl<'a> TLVElement<'a> { ElementType::U8(a) => Ok(a.into()), ElementType::U16(a) => Ok(a.into()), ElementType::U32(a) => Ok(a), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } @@ -424,7 +424,7 @@ impl<'a> TLVElement<'a> { ElementType::U16(a) => Ok(a.into()), ElementType::U32(a) => Ok(a.into()), ElementType::U64(a) => Ok(a), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } @@ -434,7 +434,7 @@ impl<'a> TLVElement<'a> { | ElementType::Utf8l(s) | ElementType::Str16l(s) | ElementType::Utf16l(s) => Ok(s), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } @@ -444,9 +444,9 @@ impl<'a> TLVElement<'a> { | ElementType::Utf8l(s) | ElementType::Str16l(s) | ElementType::Utf16l(s) => { - Ok(core::str::from_utf8(s).map_err(|_| Error::InvalidData)?) + Ok(core::str::from_utf8(s).map_err(|_| Error::from(ErrorCode::InvalidData))?) } - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } @@ -454,48 +454,48 @@ impl<'a> TLVElement<'a> { match self.element_type { ElementType::False => Ok(false), ElementType::True => Ok(true), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } pub fn null(&self) -> Result<(), Error> { match self.element_type { ElementType::Null => Ok(()), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } pub fn confirm_struct(&self) -> Result, Error> { match self.element_type { ElementType::Struct(_) => Ok(*self), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } pub fn confirm_array(&self) -> Result, Error> { match self.element_type { ElementType::Array(_) => Ok(*self), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } pub fn confirm_list(&self) -> Result, Error> { match self.element_type { ElementType::List(_) => Ok(*self), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } pub fn find_tag(&self, tag: u32) -> Result, Error> { let match_tag: TagType = TagType::Context(tag as u8); - let iter = self.enter().ok_or(Error::TLVTypeMismatch)?; + let iter = self.enter().ok_or(ErrorCode::TLVTypeMismatch)?; for a in iter { if match_tag == a.tag_type { return Ok(a); } } - Err(Error::NoTagFound) + Err(ErrorCode::NoTagFound.into()) } pub fn get_tag(&self) -> TagType { @@ -721,14 +721,17 @@ impl<'a> Iterator for TLVContainerIterator<'a> { } pub fn get_root_node(b: &[u8]) -> Result { - TLVList::new(b).iter().next().ok_or(Error::InvalidData) + Ok(TLVList::new(b) + .iter() + .next() + .ok_or(ErrorCode::InvalidData)?) } pub fn get_root_node_struct(b: &[u8]) -> Result { TLVList::new(b) .iter() .next() - .ok_or(Error::InvalidData)? + .ok_or(ErrorCode::InvalidData)? .confirm_struct() } @@ -736,7 +739,7 @@ pub fn get_root_node_list(b: &[u8]) -> Result { TLVList::new(b) .iter() .next() - .ok_or(Error::InvalidData)? + .ok_or(ErrorCode::InvalidData)? .confirm_list() } @@ -802,7 +805,7 @@ mod tests { get_root_node_list, get_root_node_struct, ElementType, Pointer, TLVElement, TLVList, TagType, }; - use crate::error::Error; + use crate::error::ErrorCode; #[test] fn test_short_length_tag() { @@ -1146,7 +1149,10 @@ mod tests { element_type: ElementType::U32(1), } ); - assert_eq!(cmd_path.find_tag(3), Err(Error::NoTagFound)); + assert_eq!( + cmd_path.find_tag(3).map_err(|e| e.code()), + Err(ErrorCode::NoTagFound) + ); // This is the variable of the invoke command assert_eq!( diff --git a/matter/src/tlv/traits.rs b/matter/src/tlv/traits.rs index 0311cb39..3fced12d 100644 --- a/matter/src/tlv/traits.rs +++ b/matter/src/tlv/traits.rs @@ -16,7 +16,7 @@ */ use super::{ElementType, TLVContainerIterator, TLVElement, TLVWriter, TagType}; -use crate::error::Error; +use crate::error::{Error, ErrorCode}; use core::fmt::Debug; use core::slice::Iter; use log::error; @@ -31,7 +31,7 @@ pub trait FromTLV<'a> { where Self: Sized, { - Err(Error::TLVNotFound) + Err(ErrorCode::TLVNotFound.into()) } } @@ -45,7 +45,8 @@ impl<'a, T: FromTLV<'a> + Default, const N: usize> FromTLV<'a> for [T; N] { let mut a = heapless::Vec::::new(); if let Some(tlv_iter) = t.enter() { for element in tlv_iter { - a.push(T::from_tlv(&element)?).map_err(|_| Error::NoSpace)?; + a.push(T::from_tlv(&element)?) + .map_err(|_| ErrorCode::NoSpace)?; } } @@ -53,10 +54,10 @@ impl<'a, T: FromTLV<'a> + Default, const N: usize> FromTLV<'a> for [T; N] { // implementation on top of heapless::Vec (to avoid requiring Copy) // Not sure why we actually need that yet, but without it unit tests fail while a.len() < N { - a.push(Default::default()).map_err(|_| Error::NoSpace)?; + a.push(Default::default()).map_err(|_| ErrorCode::NoSpace)?; } - a.into_array().map_err(|_| Error::Invalid) + a.into_array().map_err(|_| ErrorCode::Invalid.into()) } } @@ -131,7 +132,7 @@ impl<'a> UtfStr<'a> { } pub fn as_str(&self) -> Result<&str, Error> { - core::str::from_utf8(self.0).map_err(|_| Error::Invalid) + core::str::from_utf8(self.0).map_err(|_| ErrorCode::Invalid.into()) } } @@ -172,7 +173,7 @@ impl<'a> ToTLV for OctetStr<'a> { /// Implements the Owned version of Octet String impl FromTLV<'_> for heapless::Vec { fn from_tlv(t: &TLVElement) -> Result, Error> { - heapless::Vec::from_slice(t.slice()?).map_err(|_| Error::NoSpace) + heapless::Vec::from_slice(t.slice()?).map_err(|_| ErrorCode::NoSpace.into()) } } @@ -189,7 +190,7 @@ impl FromTLV<'_> for heapless::String { string .push_str(core::str::from_utf8(t.slice()?)?) - .map_err(|_| Error::NoSpace)?; + .map_err(|_| ErrorCode::NoSpace)?; Ok(string) } @@ -411,7 +412,7 @@ impl<'a> ToTLV for TLVElement<'a> { ElementType::EndCnt => tw.end_container(), _ => { error!("ToTLV Not supported"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } } @@ -419,7 +420,7 @@ impl<'a> ToTLV for TLVElement<'a> { #[cfg(test)] mod tests { - use super::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV}; + use super::{FromTLV, OctetStr, TLVWriter, TagType, ToTLV}; use crate::{error::Error, tlv::TLVList, utils::writebuf::WriteBuf}; use matter_macro_derive::{FromTLV, ToTLV}; diff --git a/matter/src/tlv/writer.rs b/matter/src/tlv/writer.rs index 1db84210..45c60c97 100644 --- a/matter/src/tlv/writer.rs +++ b/matter/src/tlv/writer.rs @@ -164,7 +164,7 @@ impl<'a, 'b> TLVWriter<'a, 'b> { pub fn str8(&mut self, tag_type: TagType, data: &[u8]) -> Result<(), Error> { if data.len() > 256 { error!("use str16() instead"); - return Err(Error::Invalid); + return Err(ErrorCode::Invalid.into()); } self.put_control_tag(tag_type, WriteElementType::Str8l)?; self.buf.le_u8(data.len() as u8)?; diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index 5d7a79c8..5a9bbcf7 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -21,11 +21,10 @@ use core::fmt; use core::time::Duration; use log::{error, info, trace}; -use crate::error::Error; +use crate::error::{Error, ErrorCode}; use crate::interaction_model::core::{ResumeReadReq, ResumeSubscribeReq}; use crate::secure_channel; use crate::secure_channel::case::CaseSession; -use crate::tlv::print_tlv_list; use crate::utils::epoch::Epoch; use crate::utils::rand::Rand; @@ -227,7 +226,8 @@ impl Exchange { tx.get_proto_id(), tx.get_proto_opcode(), ); - print_tlv_list(tx.as_slice()); + + //print_tlv_list(tx.as_slice()); tx.proto.exch_id = self.id; if self.role == Role::Initiator { @@ -317,10 +317,10 @@ impl ExchangeMgr { info!("Creating new exchange"); let e = Exchange::new(id, sess_idx, role); if exchanges.insert(id, e).is_err() { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } } else { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } } @@ -330,11 +330,11 @@ impl ExchangeMgr { if result.get_role() == role && sess_idx == result.sess_idx { Ok(result) } else { - Err(Error::NoExchange) + Err(ErrorCode::NoExchange.into()) } } else { error!("This should never happen"); - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } } @@ -375,7 +375,7 @@ impl ExchangeMgr { pub fn send(&mut self, exch_id: u16, tx: &mut Packet) -> Result { let exchange = - ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id).ok_or(Error::NoExchange)?; + ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id).ok_or(ErrorCode::NoExchange)?; let mut session = self.sess_mgr.get_session_handle(exchange.sess_idx); exchange.send(tx, &mut session) } @@ -474,7 +474,7 @@ impl fmt::Display for ExchangeMgr { #[allow(clippy::bool_assert_comparison)] mod tests { use crate::{ - error::Error, + error::ErrorCode, transport::{ network::Address, session::{CloneData, SessionMode}, @@ -532,9 +532,12 @@ mod tests { let clone_data = get_clone_data(peer_sess_id, local_sess_id); match mgr.add_session(&clone_data) { Ok(s) => assert_eq!(peer_sess_id, s.get_peer_sess_id()), - Err(Error::NoSpace) => break, - _ => { - panic!("Couldn't, create session"); + Err(e) => { + if e.code() == ErrorCode::NoSpace { + break; + } else { + panic!("Could not create sessions"); + } } } local_sess_id += 1; @@ -576,7 +579,10 @@ mod tests { for i in 1..(MAX_SESSIONS + 1) { // Now purposefully overflow the sessions by adding another session let result = mgr.add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id)); - assert!(matches!(result, Err(Error::NoSpace))); + assert!(matches!( + result.map_err(|e| e.code()), + Err(ErrorCode::NoSpace) + )); let mut buf = [0; MAX_TX_BUF_SIZE]; let tx = &mut Packet::new_tx(&mut buf); diff --git a/matter/src/transport/mgr.rs b/matter/src/transport/mgr.rs index 0db63904..331c3625 100644 --- a/matter/src/transport/mgr.rs +++ b/matter/src/transport/mgr.rs @@ -109,14 +109,18 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { } } Ok(None) => (RecvState::Ack, None), - Err(Error::Duplicate) => (RecvState::Ack, None), - Err(Error::NoSpace) => (RecvState::EvictSession, None), - Err(err) => Err(err)?, + Err(e) => match e.code() { + ErrorCode::Duplicate => (RecvState::Ack, None), + ErrorCode::NoSpace => (RecvState::EvictSession, None), + _ => Err(e)?, + }, }, RecvState::AddSession(clone_data) => match self.mgr.exch_mgr.add_session(&clone_data) { Ok(_) => (RecvState::Ack, None), - Err(Error::NoSpace) => (RecvState::EvictSession2(clone_data), None), - Err(err) => Err(err)?, + Err(e) => match e.code() { + ErrorCode::NoSpace => (RecvState::EvictSession2(clone_data), None), + _ => Err(e)?, + }, }, RecvState::EvictSession => { if self.mgr.exch_mgr.evict_session(&mut self.tx)? { diff --git a/matter/src/transport/mrp.rs b/matter/src/transport/mrp.rs index 2213a524..2d046bf5 100644 --- a/matter/src/transport/mrp.rs +++ b/matter/src/transport/mrp.rs @@ -59,7 +59,7 @@ impl AckEntry { ack_timeout, }) } else { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } @@ -120,7 +120,7 @@ impl ReliableMessage { if self.retrans.is_some() { // This indicates there was some existing entry for same sess-id/exch-id, which shouldnt happen error!("Previous retrans entry for this exchange already exists"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } self.retrans = Some(RetransEntry::new(proto_tx.plain.ctr)); @@ -135,7 +135,7 @@ impl ReliableMessage { pub fn recv(&mut self, proto_rx: &Packet, epoch: Epoch) -> Result<(), Error> { if proto_rx.proto.is_ack() { // Handle received Acks - let ack_msg_ctr = proto_rx.proto.get_ack_msg_ctr().ok_or(Error::Invalid)?; + let ack_msg_ctr = proto_rx.proto.get_ack_msg_ctr().ok_or(ErrorCode::Invalid)?; if let Some(entry) = &self.retrans { if entry.get_msg_ctr() != ack_msg_ctr { // TODO: XXX Fix this @@ -150,7 +150,7 @@ impl ReliableMessage { // This indicates there was some existing entry for same sess-id/exch-id, which shouldnt happen // TODO: As per the spec if this happens, we need to send out the previous ACK and note this new ACK error!("Previous ACK entry for this exchange already exists"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } self.ack = Some(AckEntry::new(proto_rx.plain.ctr, epoch)?); diff --git a/matter/src/transport/packet.rs b/matter/src/transport/packet.rs index d56485f1..3e7e9c75 100644 --- a/matter/src/transport/packet.rs +++ b/matter/src/transport/packet.rs @@ -18,7 +18,7 @@ use log::error; use crate::{ - error::Error, + error::{Error, ErrorCode}, utils::{parsebuf::ParseBuf, writebuf::WriteBuf}, }; @@ -109,7 +109,7 @@ impl<'a> Packet<'a> { if let Direction::Rx(pbuf, _) = &mut self.data { Ok(pbuf) } else { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } @@ -117,7 +117,7 @@ impl<'a> Packet<'a> { if let Direction::Tx(wbuf) = &mut self.data { Ok(wbuf) } else { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } @@ -158,10 +158,10 @@ impl<'a> Packet<'a> { .decrypt_and_decode(&self.plain, pb, peer_nodeid, dec_key) } else { error!("Invalid state for proto_decode"); - Err(Error::InvalidState) + Err(ErrorCode::InvalidState.into()) } } - _ => Err(Error::InvalidState), + _ => Err(ErrorCode::InvalidState.into()), } } @@ -171,7 +171,7 @@ impl<'a> Packet<'a> { RxState::Uninit => Ok(false), _ => Ok(true), }, - _ => Err(Error::InvalidState), + _ => Err(ErrorCode::InvalidState.into()), } } @@ -183,10 +183,10 @@ impl<'a> Packet<'a> { self.plain.decode(pb) } else { error!("Invalid state for plain_decode"); - Err(Error::InvalidState) + Err(ErrorCode::InvalidState.into()) } } - _ => Err(Error::InvalidState), + _ => Err(ErrorCode::InvalidState.into()), } } } diff --git a/matter/src/transport/plain_hdr.rs b/matter/src/transport/plain_hdr.rs index e51ddaf0..e5a9b24e 100644 --- a/matter/src/transport/plain_hdr.rs +++ b/matter/src/transport/plain_hdr.rs @@ -65,7 +65,7 @@ impl PlainHdr { impl PlainHdr { // it will have an additional 'message length' field first pub fn decode(&mut self, msg: &mut ParseBuf) -> Result<(), Error> { - self.flags = MsgFlags::from_bits(msg.le_u8()?).ok_or(Error::Invalid)?; + self.flags = MsgFlags::from_bits(msg.le_u8()?).ok_or(ErrorCode::Invalid)?; self.sess_id = msg.le_u16()?; let _sec_flags = msg.le_u8()?; self.sess_type = if self.sess_id != 0 { diff --git a/matter/src/transport/proto_hdr.rs b/matter/src/transport/proto_hdr.rs index fd392bd2..d7f92fb4 100644 --- a/matter/src/transport/proto_hdr.rs +++ b/matter/src/transport/proto_hdr.rs @@ -105,7 +105,7 @@ impl ProtoHdr { decrypt_in_place(plain_hdr.ctr, peer_nodeid, parsebuf, d)?; } - self.exch_flags = ExchFlags::from_bits(parsebuf.le_u8()?).ok_or(Error::Invalid)?; + self.exch_flags = ExchFlags::from_bits(parsebuf.le_u8()?).ok_or(ErrorCode::Invalid)?; self.proto_opcode = parsebuf.le_u8()?; self.exch_id = parsebuf.le_u16()?; self.proto_id = parsebuf.le_u16()?; @@ -128,10 +128,10 @@ impl ProtoHdr { resp_buf.le_u16(self.exch_id)?; resp_buf.le_u16(self.proto_id)?; if self.is_vendor() { - resp_buf.le_u16(self.proto_vendor_id.ok_or(Error::Invalid)?)?; + resp_buf.le_u16(self.proto_vendor_id.ok_or(ErrorCode::Invalid)?)?; } if self.is_ack() { - resp_buf.le_u32(self.ack_msg_ctr.ok_or(Error::Invalid)?)?; + resp_buf.le_u32(self.ack_msg_ctr.ok_or(ErrorCode::Invalid)?)?; } Ok(()) } @@ -216,7 +216,7 @@ fn decrypt_in_place( // If so, we need to handle it cleanly here. aad.copy_from_slice(parsed_slice); } else { - return Err(Error::InvalidAAD); + Err(ErrorCode::InvalidAAD)?; } // IV: diff --git a/matter/src/transport/session.rs b/matter/src/transport/session.rs index 95597e2f..1135f05d 100644 --- a/matter/src/transport/session.rs +++ b/matter/src/transport/session.rs @@ -410,7 +410,7 @@ impl SessionMgr { self.sessions[index] = Some(session); Ok(index) } else { - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } } @@ -465,7 +465,7 @@ impl SessionMgr { info!("Creating new session"); self.add(peer_addr, peer_nodeid) } else { - Err(Error::NotFound) + Err(ErrorCode::NotFound.into()) } } @@ -484,14 +484,14 @@ impl SessionMgr { let duplicate = session.rx_ctr_state.recv(rx.plain.ctr, is_encrypted); if duplicate { info!("Dropping duplicate packet"); - Err(Error::Duplicate) + Err(ErrorCode::Duplicate.into()) } else { Ok(sess_index) } } pub fn decode(&mut self, rx: &mut Packet) -> Result<(), Error> { - // let network = self.network.as_ref().ok_or(Error::NoNetworkInterface)?; + // let network = self.network.as_ref().ok_or(ErrorCode::NoNetworkInterface)?; // let (len, src) = network.recv(rx.as_borrow_slice()).await?; // rx.get_parsebuf()?.set_len(len); @@ -507,7 +507,7 @@ impl SessionMgr { pub fn send(&mut self, sess_idx: usize, tx: &mut Packet) -> Result<(), Error> { self.sessions[sess_idx] .as_mut() - .ok_or(Error::NoSession)? + .ok_or(ErrorCode::NoSession)? .do_send(self.epoch, tx)?; // let network = self.network.as_ref().ok_or(Error::NoNetworkInterface)?; diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index b3c4c484..b29ca05a 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -53,7 +53,7 @@ impl UdpListener { let (size, addr) = self.socket.recv_from(in_buf).await.map_err(|e| { warn!("Error on the network: {:?}", e); - Error::Network + ErrorCode::Network })?; info!("Got packet: {:?} from addr {:?}", &in_buf[..size], addr); @@ -66,7 +66,7 @@ impl UdpListener { Address::Udp(addr) => { let len = self.socket.send_to(out_buf, addr).await.map_err(|e| { warn!("Error on the network: {:?}", e); - Error::Network + ErrorCode::Network })?; info!( diff --git a/matter/src/utils/parsebuf.rs b/matter/src/utils/parsebuf.rs index d6a8b9a9..549e022b 100644 --- a/matter/src/utils/parsebuf.rs +++ b/matter/src/utils/parsebuf.rs @@ -65,7 +65,7 @@ impl<'a> ParseBuf<'a> { self.left -= size; return Ok(tail); } - Err(Error::TruncatedPacket) + Err(ErrorCode::TruncatedPacket.into()) } fn advance(&mut self, len: usize) { @@ -82,7 +82,7 @@ impl<'a> ParseBuf<'a> { self.advance(size); return Ok(data); } - Err(Error::TruncatedPacket) + Err(ErrorCode::TruncatedPacket.into()) } pub fn le_u8(&mut self) -> Result { diff --git a/matter/src/utils/writebuf.rs b/matter/src/utils/writebuf.rs index 3adafe2f..21a51e28 100644 --- a/matter/src/utils/writebuf.rs +++ b/matter/src/utils/writebuf.rs @@ -70,9 +70,9 @@ impl<'a> WriteBuf<'a> { pub fn reserve(&mut self, reserve: usize) -> Result<(), Error> { if self.end != 0 || self.start != 0 || self.buf_size != self.buf.len() { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } else if reserve > self.buf_size { - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } else { self.start = reserve; self.end = reserve; @@ -85,7 +85,7 @@ impl<'a> WriteBuf<'a> { self.buf_size -= with; Ok(()) } else { - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } } @@ -94,7 +94,7 @@ impl<'a> WriteBuf<'a> { self.buf_size += by; Ok(()) } else { - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } } @@ -107,7 +107,7 @@ impl<'a> WriteBuf<'a> { self.start -= size; return Ok(()); } - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } pub fn prepend(&mut self, src: &[u8]) -> Result<(), Error> { @@ -126,7 +126,7 @@ impl<'a> WriteBuf<'a> { self.end += size; return Ok(()); } - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } pub fn append(&mut self, src: &[u8]) -> Result<(), Error> { diff --git a/matter/tests/common/echo_cluster.rs b/matter/tests/common/echo_cluster.rs index dd61a0e7..e5caca70 100644 --- a/matter/tests/common/echo_cluster.rs +++ b/matter/tests/common/echo_cluster.rs @@ -27,7 +27,7 @@ use matter::{ Cluster, CmdDataEncoder, CmdDataWriter, CmdDetails, Dataver, Handler, Quality, ATTRIBUTE_LIST, FEATURE_MAP, }, - error::Error, + error::{Error, ErrorCode}, interaction_model::{ core::Transaction, messages::ib::{attr_list_write, ListOperation}, @@ -122,7 +122,7 @@ impl TestChecker { INIT.call_once(|| { G_TEST_CHECKER = Some(Arc::new(Mutex::new(Self::new()))); }); - Ok(G_TEST_CHECKER.as_ref().ok_or(Error::Invalid)?.clone()) + Ok(G_TEST_CHECKER.as_ref().ok_or(ErrorCode::Invalid)?.clone()) } } } @@ -235,7 +235,7 @@ impl EchoCluster { } } - Err(Error::ResourceExhausted) + Err(ErrorCode::ResourceExhausted.into()) } ListOperation::EditItem(index) => { let data = data.u16()?; @@ -243,7 +243,7 @@ impl EchoCluster { tc.write_list[*index as usize] = Some(data); Ok(()) } else { - Err(Error::InvalidAction) + Err(ErrorCode::InvalidAction.into()) } } ListOperation::DeleteItem(index) => { @@ -251,7 +251,7 @@ impl EchoCluster { tc.write_list[*index as usize] = None; Ok(()) } else { - Err(Error::InvalidAction) + Err(ErrorCode::InvalidAction.into()) } } ListOperation::DeleteList => { diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index 86674ecf..e1402823 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -216,7 +216,7 @@ impl<'a> ImEngine<'a> { epoch: *self.matter.borrow(), }; let mut rx_buf = [0; MAX_RX_BUF_SIZE]; - let mut tx_buf = [0; 1450]; // For the long read tests to run unchanged + let mut tx_buf = [0; 1440]; // For the long read tests to run unchanged let mut rx = Packet::new_rx(&mut rx_buf); let mut tx = Packet::new_tx(&mut tx_buf); // Create fake rx packet diff --git a/matter_macro_derive/Cargo.toml b/matter_macro_derive/Cargo.toml index f0f38ba4..163ff502 100644 --- a/matter_macro_derive/Cargo.toml +++ b/matter_macro_derive/Cargo.toml @@ -11,3 +11,4 @@ proc-macro = true syn = { version = "1", features = ["extra-traits"]} quote = "1" proc-macro2 = "1" +proc-macro-crate = "1.3" diff --git a/matter_macro_derive/src/lib.rs b/matter_macro_derive/src/lib.rs index a1fc5532..c63eddc4 100644 --- a/matter_macro_derive/src/lib.rs +++ b/matter_macro_derive/src/lib.rs @@ -16,7 +16,7 @@ */ use proc_macro::TokenStream; -use proc_macro2::Span; +use proc_macro2::{Ident, Span}; use quote::{format_ident, quote}; use syn::Lit::{Int, Str}; use syn::NestedMeta::{Lit, Meta}; @@ -106,6 +106,18 @@ fn parse_tag_val(field: &syn::Field) -> Option { None } +fn get_crate_name() -> String { + let found_crate = proc_macro_crate::crate_name("matter-iot").unwrap_or_else(|err| { + eprintln!("Warning: defaulting to `crate` {err}"); + proc_macro_crate::FoundCrate::Itself + }); + + match found_crate { + proc_macro_crate::FoundCrate::Itself => String::from("crate"), + proc_macro_crate::FoundCrate::Name(name) => name, + } +} + /// Generate a ToTlv implementation for a structure fn gen_totlv_for_struct( fields: &syn::FieldsNamed, @@ -187,16 +199,18 @@ fn gen_totlv_for_enum( tag_start += 1; } + let krate = Ident::new(&get_crate_name(), Span::call_site()); + let expanded = quote! { - impl #generics ToTLV for #enum_name #generics { - fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { + impl #generics #krate::tlv::ToTLV for #enum_name #generics { + fn to_tlv(&self, tw: &mut #krate::tlv::TLVWriter, tag_type: #krate::tlv::TagType) -> Result<(), #krate::error::Error> { let anchor = tw.get_tail(); if let Err(err) = (|| { tw.start_struct(tag_type)?; match self { #( - Self::#variant_names(c) => { c.to_tlv(tw, TagType::Context(#tags))?; }, + Self::#variant_names(c) => { c.to_tlv(tw, #krate::tlv::TagType::Context(#tags))?; }, )* } tw.end_container() @@ -297,14 +311,16 @@ fn gen_fromtlv_for_struct( } } + let krate = Ident::new(&get_crate_name(), Span::call_site()); + // Currently we don't use find_tag() because the tags come in sequential // order. If ever the tags start coming out of order, we can use find_tag() // instead let expanded = if !tlvargs.unordered { quote! { - impl #generics FromTLV <#lifetime> for #struct_name #generics { - fn from_tlv(t: &TLVElement<#lifetime>) -> Result { - let mut t_iter = t.#datatype ()?.enter().ok_or(Error::Invalid)?; + impl #generics #krate::tlv::FromTLV <#lifetime> for #struct_name #generics { + fn from_tlv(t: &#krate::tlv::TLVElement<#lifetime>) -> Result { + let mut t_iter = t.#datatype ()?.enter().ok_or_else(|| #krate::error::Error::new(#krate::error::ErrorCode::Invalid))?; let mut item = t_iter.next(); #( let #idents = if Some(true) == item.map(|x| x.check_ctx_tag(#tags)) { @@ -324,8 +340,8 @@ fn gen_fromtlv_for_struct( } } else { quote! { - impl #generics FromTLV <#lifetime> for #struct_name #generics { - fn from_tlv(t: &TLVElement<#lifetime>) -> Result { + impl #generics #krate::tlv::FromTLV <#lifetime> for #struct_name #generics { + fn from_tlv(t: &#krate::tlv::TLVElement<#lifetime>) -> Result { #( let #idents = if let Ok(s) = t.find_tag(#tags as u32) { #types::from_tlv(&s) @@ -375,20 +391,22 @@ fn gen_fromtlv_for_enum( tag_start += 1; } + let krate = Ident::new(&get_crate_name(), Span::call_site()); + let expanded = quote! { - impl #generics FromTLV <#lifetime> for #enum_name #generics { - fn from_tlv(t: &TLVElement<#lifetime>) -> Result { - let mut t_iter = t.confirm_struct()?.enter().ok_or(Error::Invalid)?; - let mut item = t_iter.next().ok_or(Error::Invalid)?; + impl #generics #krate::tlv::FromTLV <#lifetime> for #enum_name #generics { + fn from_tlv(t: &#krate::tlv::TLVElement<#lifetime>) -> Result { + let mut t_iter = t.confirm_struct()?.enter().ok_or_else(|| #krate::error::Error::new(#krate::error::ErrorCode::Invalid))?; + let mut item = t_iter.next().ok_or_else(|| Error::new(#krate::error::ErrorCode::Invalid))?; if let TagType::Context(tag) = item.get_tag() { match tag { #( #tags => Ok(Self::#variant_names(#types::from_tlv(&item)?)), )* - _ => Err(Error::Invalid), + _ => Err(#krate::error::Error::new(#krate::error::ErrorCode::Invalid)), } } else { - Err(Error::TLVTypeMismatch) + Err(#krate::error::Error::new(#krate::error::ErrorCode::TLVTypeMismatch)) } } } From 89014ed7f2ff174224ff74f4e4a78b66a43fb93f Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sat, 29 Apr 2023 17:20:55 +0000 Subject: [PATCH 35/72] Remove heapless::String from QR API --- examples/onoff_light/src/main.rs | 6 +-- matter/src/core.rs | 10 ++-- matter/src/pairing/code.rs | 2 +- matter/src/pairing/mod.rs | 14 +++--- matter/src/pairing/qr.rs | 78 ++++++++++++++++++++++---------- 5 files changed, 69 insertions(+), 41 deletions(-) diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 604ffce3..3b943058 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -54,8 +54,8 @@ fn main() -> Result<(), impl Error> { device_name: "OnOff Light", }; - let mut mdns = matter::mdns::astro::AstroMdns::new()?; - //let mut mdns = matter::mdns::libmdns::LibMdns::new()?; + //let mut mdns = matter::mdns::astro::AstroMdns::new()?; + let mut mdns = matter::mdns::libmdns::LibMdns::new()?; //let mut mdns = matter::mdns::DummyMdns {}; let matter = Matter::new_default(&dev_info, &mut mdns, matter::transport::udp::MATTER_PORT); @@ -77,7 +77,7 @@ fn main() -> Result<(), impl Error> { matter.load_fabrics(data)?; } - matter.start::<4096>( + matter.start( CommissioningData { // TODO: Hard-coded for now verifier: VerifierData::new_with_pw(123456, *matter.borrow()), diff --git a/matter/src/core.rs b/matter/src/core.rs index e2e6b597..17452e3c 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -116,19 +116,15 @@ impl<'a> Matter<'a> { self.acl_mgr.borrow_mut().store(buf) } - pub fn start( - &self, - dev_comm: CommissioningData, - buf: &mut [u8], - ) -> Result<(), Error> { + pub fn start(&self, dev_comm: CommissioningData, buf: &mut [u8]) -> Result<(), Error> { let open_comm_window = self.fabric_mgr.borrow().is_empty(); if open_comm_window { - print_pairing_code_and_qr::( + print_pairing_code_and_qr( self.dev_det, &dev_comm, DiscoveryCapabilities::default(), buf, - ); + )?; self.pase_mgr.borrow_mut().enable_pase_session( dev_comm.verifier, diff --git a/matter/src/pairing/code.rs b/matter/src/pairing/code.rs index 16e4feab..91061166 100644 --- a/matter/src/pairing/code.rs +++ b/matter/src/pairing/code.rs @@ -19,7 +19,7 @@ use core::fmt::Write; use super::*; -pub(super) fn compute_pairing_code(comm_data: &CommissioningData) -> heapless::String<32> { +pub fn compute_pairing_code(comm_data: &CommissioningData) -> heapless::String<32> { // 0: no Vendor ID and Product ID present in Manual Pairing Code const VID_PID_PRESENT: u8 = 0; diff --git a/matter/src/pairing/mod.rs b/matter/src/pairing/mod.rs index 96f3105e..2dddce56 100644 --- a/matter/src/pairing/mod.rs +++ b/matter/src/pairing/mod.rs @@ -31,7 +31,7 @@ use crate::{ use self::{ code::{compute_pairing_code, pretty_print_pairing_code}, - qr::{payload_base38_representation, print_qr_code, QrSetupPayload}, + qr::{compute_qr_code, print_qr_code}, }; pub struct DiscoveryCapabilities { @@ -81,19 +81,19 @@ impl DiscoveryCapabilities { } /// Prepares and prints the pairing code and the QR code for easy pairing. -pub fn print_pairing_code_and_qr( +pub fn print_pairing_code_and_qr( dev_det: &BasicInfoConfig, comm_data: &CommissioningData, discovery_capabilities: DiscoveryCapabilities, buf: &mut [u8], -) { +) -> Result<(), Error> { let pairing_code = compute_pairing_code(comm_data); - let qr_code_data = QrSetupPayload::new(dev_det, comm_data, discovery_capabilities); - let data_str = - payload_base38_representation::(&qr_code_data, buf).expect("Failed to encode"); + let qr_code = compute_qr_code(dev_det, comm_data, discovery_capabilities, buf)?; pretty_print_pairing_code(&pairing_code); - print_qr_code(&data_str); + print_qr_code(&qr_code); + + Ok(()) } pub(self) fn passwd_from_comm_data(comm_data: &CommissioningData) -> u32 { diff --git a/matter/src/pairing/qr.rs b/matter/src/pairing/qr.rs index aa264811..3550dfb8 100644 --- a/matter/src/pairing/qr.rs +++ b/matter/src/pairing/qr.rs @@ -253,20 +253,24 @@ pub enum CommissionningFlowType { Custom = 2, } -pub(super) fn payload_base38_representation( +pub(super) fn payload_base38_representation<'a>( payload: &QrSetupPayload, - buf: &mut [u8], -) -> Result, Error> { + buf: &'a mut [u8], +) -> Result<&'a str, Error> { if payload.is_valid() { - let (bits_buf, tlv_buf) = if payload.has_tlv() { - let (bits_buf, tlv_buf) = buf.split_at_mut(buf.len() / 2); + let (str_buf, bits_buf, tlv_buf) = if payload.has_tlv() { + let (str_buf, buf) = buf.split_at_mut(buf.len() / 3 * 2); - (bits_buf, Some(tlv_buf)) + let (bits_buf, tlv_buf) = buf.split_at_mut(buf.len() / 3); + + (str_buf, bits_buf, Some(tlv_buf)) } else { - (buf, None) + let (str_buf, buf) = buf.split_at_mut(buf.len() / 3 * 2); + + (str_buf, buf, None) }; - payload_base38_representation_with_tlv(payload, bits_buf, tlv_buf) + payload_base38_representation_with_tlv(payload, str_buf, bits_buf, tlv_buf) } else { Err(ErrorCode::InvalidArgument.into()) } @@ -315,16 +319,16 @@ fn estimate_struct_overhead(first_field_size: usize) -> usize { first_field_size + 4 + 2 } -pub(super) fn print_qr_code(qr_data: &str) { - info!("QR Code: {}", qr_data); +pub(super) fn print_qr_code(qr_code: &str) { + info!("QR Code: {}", qr_code); #[cfg(feature = "std")] { use qrcode::{render::unicode, QrCode, Version}; - let needed_version = compute_qr_version(qr_data); + let needed_version = compute_qr_version(qr_code); let code = - QrCode::with_version(qr_data, Version::Normal(needed_version), qrcode::EcLevel::M) + QrCode::with_version(qr_code, Version::Normal(needed_version), qrcode::EcLevel::M) .unwrap(); let image = code .render::() @@ -336,6 +340,16 @@ pub(super) fn print_qr_code(qr_data: &str) { } } +pub fn compute_qr_code<'a>( + dev_det: &BasicInfoConfig, + comm_data: &CommissioningData, + discovery_capabilities: DiscoveryCapabilities, + buf: &'a mut [u8], +) -> Result<&'a str, Error> { + let qr_code_data = QrSetupPayload::new(dev_det, comm_data, discovery_capabilities); + payload_base38_representation(&qr_code_data, buf) +} + fn compute_qr_version(qr_data: &str) -> i16 { match qr_data.len() { 0..=38 => 2, @@ -375,11 +389,12 @@ fn populate_bits( Ok(()) } -fn payload_base38_representation_with_tlv( +fn payload_base38_representation_with_tlv<'a>( payload: &QrSetupPayload, + str_buf: &'a mut [u8], bits_buf: &mut [u8], tlv_buf: Option<&mut [u8]>, -) -> Result, Error> { +) -> Result<&'a str, Error> { let tlv_data = if let Some(tlv_buf) = tlv_buf { Some(generate_tlv_from_optional_data(payload, tlv_buf)?) } else { @@ -388,13 +403,30 @@ fn payload_base38_representation_with_tlv( let bits = generate_bit_set(payload, bits_buf, tlv_data)?; - let mut base38_encoded: heapless::String = "MT:".into(); + let prefix = "MT:"; + + if str_buf.len() < prefix.as_bytes().len() { + Err(ErrorCode::NoSpace)?; + } + + str_buf[..prefix.as_bytes().len()].copy_from_slice(prefix.as_bytes()); + + let mut offset = prefix.len(); for c in base38::encode(bits) { - base38_encoded.push(c).map_err(|_| ErrorCode::NoSpace)?; + let mut char_buf = [0; 4]; + let str = c.encode_utf8(&mut char_buf); + + if str_buf.len() - offset < str.as_bytes().len() { + Err(ErrorCode::NoSpace)?; + } + + str_buf[offset..offset + str.as_bytes().len()].copy_from_slice(str.as_bytes()); + + offset += str.as_bytes().len(); } - Ok(base38_encoded) + Ok(core::str::from_utf8(&str_buf[..offset])?) } fn generate_tlv_from_optional_data<'a>( @@ -557,8 +589,8 @@ mod tests { let disc_cap = DiscoveryCapabilities::new(false, true, false); let qr_code_data = QrSetupPayload::new(&dev_det, &comm_data, disc_cap); let mut buf = [0; 1024]; - let data_str = payload_base38_representation::<128>(&qr_code_data, &mut buf) - .expect("Failed to encode"); + let data_str = + payload_base38_representation(&qr_code_data, &mut buf).expect("Failed to encode"); assert_eq!(data_str, QR_CODE) } @@ -580,8 +612,8 @@ mod tests { let disc_cap = DiscoveryCapabilities::new(true, false, false); let qr_code_data = QrSetupPayload::new(&dev_det, &comm_data, disc_cap); let mut buf = [0; 1024]; - let data_str = payload_base38_representation::<128>(&qr_code_data, &mut buf) - .expect("Failed to encode"); + let data_str = + payload_base38_representation(&qr_code_data, &mut buf).expect("Failed to encode"); assert_eq!(data_str, QR_CODE) } @@ -626,8 +658,8 @@ mod tests { .expect("Failed to add optional data"); let mut buf = [0; 1024]; - let data_str = payload_base38_representation::<{ QR_CODE.len() }>(&qr_code_data, &mut buf) - .expect("Failed to encode"); + let data_str = + payload_base38_representation(&qr_code_data, &mut buf).expect("Failed to encode"); assert_eq!(data_str, QR_CODE) } } From 06b0fcd6f5a9b8a150af0f858f55d3f1baf01c23 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Thu, 4 May 2023 05:26:13 +0000 Subject: [PATCH 36/72] Fix no_std errors --- matter/tests/common/im_engine.rs | 30 ++++++++++------------ matter/tests/common/mod.rs | 7 +++++ matter/tests/data_model/acl_and_dataver.rs | 19 +++++++------- matter/tests/data_model/attribute_lists.rs | 3 ++- matter/tests/data_model/attributes.rs | 17 ++++++------ matter/tests/data_model/commands.rs | 9 ++++--- matter/tests/data_model/long_reads.rs | 7 ++--- matter/tests/data_model/timed_requests.rs | 9 ++++--- matter/tests/interaction_model.rs | 7 ++--- 9 files changed, 60 insertions(+), 48 deletions(-) diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index e1402823..6d398a4f 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -44,19 +44,14 @@ use matter::{ transport::packet::Packet, transport::{ exchange::{self, Exchange, ExchangeCtx}, - network::Address, + network::{Address, IpAddr, Ipv4Addr, SocketAddr}, packet::MAX_RX_BUF_SIZE, proto_ctx::ProtoCtx, session::{CaseDetails, CloneData, NocCatIds, SessionMgr, SessionMode}, }, - utils::{ - epoch::{sys_epoch, sys_utc_calendar}, - rand::dummy_rand, - writebuf::WriteBuf, - }, + utils::{rand::dummy_rand, writebuf::WriteBuf}, Matter, }; -use std::net::{Ipv4Addr, SocketAddr}; use super::echo_cluster::EchoCluster; @@ -109,14 +104,17 @@ impl<'a> ImInput<'a> { pub type DmHandler<'a> = handler_chain_type!(OnOffCluster, EchoCluster, DescriptorCluster, EchoCluster | RootEndpointHandler<'a>); pub fn matter(mdns: &mut dyn Mdns) -> Matter<'_> { - Matter::new( - &BASIC_INFO, - mdns, - sys_epoch, - dummy_rand, - sys_utc_calendar, - 5540, - ) + #[cfg(feature = "std")] + use matter::utils::epoch::sys_epoch as epoch; + #[cfg(feature = "std")] + use matter::utils::epoch::sys_utc_calendar as utc_calendar; + + #[cfg(not(feature = "std"))] + use matter::utils::epoch::dummy_epoch as epoch; + #[cfg(not(feature = "std"))] + use matter::utils::epoch::dummy_utc_calendar as utc_calendar; + + Matter::new(&BASIC_INFO, mdns, epoch, dummy_rand, utc_calendar, 5540) } /// An Interaction Model Engine to facilitate easy testing @@ -203,7 +201,7 @@ impl<'a> ImEngine<'a> { 10, 30, Address::Udp(SocketAddr::new( - std::net::IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 5542, )), SessionMode::Case(CaseDetails::new(1, &input.cat_ids)), diff --git a/matter/tests/common/mod.rs b/matter/tests/common/mod.rs index 0d2cc9c8..96682b98 100644 --- a/matter/tests/common/mod.rs +++ b/matter/tests/common/mod.rs @@ -20,3 +20,10 @@ pub mod commands; pub mod echo_cluster; pub mod handlers; pub mod im_engine; + +pub fn init_env_logger() { + #[cfg(feature = "std")] + { + let _ = env_logger::try_init(); + } +} diff --git a/matter/tests/data_model/acl_and_dataver.rs b/matter/tests/data_model/acl_and_dataver.rs index 535555ba..81f220ac 100644 --- a/matter/tests/data_model/acl_and_dataver.rs +++ b/matter/tests/data_model/acl_and_dataver.rs @@ -36,6 +36,7 @@ use crate::{ attributes::*, echo_cluster::{self, ATTR_WRITE_DEFAULT_VALUE}, im_engine::{matter, ImEngine}, + init_env_logger, }, }; @@ -43,7 +44,7 @@ use crate::{ /// Ensure that wildcard read attributes don't include error response /// and silently drop the data when access is not granted fn wc_read_attribute() { - let _ = env_logger::try_init(); + init_env_logger(); let wc_att1 = GenericPath::new( None, @@ -101,7 +102,7 @@ fn wc_read_attribute() { /// Ensure that exact read attribute includes error response /// when access is not granted fn exact_read_attribute() { - let _ = env_logger::try_init(); + init_env_logger(); let wc_att1 = GenericPath::new( Some(0), @@ -139,7 +140,7 @@ fn exact_read_attribute() { /// Ensure that an write attribute with a wildcard either performs the operation, /// if allowed, or silently drops the request fn wc_write_attribute() { - let _ = env_logger::try_init(); + init_env_logger(); let val0 = 10; let val1 = 20; let attr_data0 = |tag, t: &mut TLVWriter| { @@ -228,7 +229,7 @@ fn wc_write_attribute() { /// Ensure that an write attribute without a wildcard returns an error when the /// ACL disallows the access, and returns success once access is granted fn exact_write_attribute() { - let _ = env_logger::try_init(); + init_env_logger(); let val0 = 10; let attr_data0 = |tag, t: &mut TLVWriter| { let _ = t.u16(tag, val0); @@ -278,7 +279,7 @@ fn exact_write_attribute() { /// ACL disallows the access, and returns success once access is granted to the CAT ID /// The Accessor CAT version is one more than that in the ACL fn exact_write_attribute_noc_cat() { - let _ = env_logger::try_init(); + init_env_logger(); let val0 = 10; let attr_data0 = |tag, t: &mut TLVWriter| { let _ = t.u16(tag, val0); @@ -330,7 +331,7 @@ fn exact_write_attribute_noc_cat() { #[test] /// Ensure that a write attribute with insufficient permissions is rejected fn insufficient_perms_write() { - let _ = env_logger::try_init(); + init_env_logger(); let val0 = 10; let attr_data0 = |tag, t: &mut TLVWriter| { let _ = t.u16(tag, val0); @@ -379,7 +380,7 @@ fn insufficient_perms_write() { /// - Write Attr to ACL Cluster (allowed, this ACL also grants universal access) /// - Write Attr to Echo Cluster again (successful this time) fn write_with_runtime_acl_add() { - let _ = env_logger::try_init(); + init_env_logger(); let peer = 98765; let mut mdns = DummyMdns {}; let matter = matter(&mut mdns); @@ -446,7 +447,7 @@ fn test_read_data_ver() { // 1 Attr Read Requests // - wildcard endpoint, att1 // - 2 responses are expected - let _ = env_logger::try_init(); + init_env_logger(); let peer = 98765; let mut mdns = DummyMdns {}; let matter = matter(&mut mdns); @@ -549,7 +550,7 @@ fn test_write_data_ver() { // 1 Attr Read Requests // - wildcard endpoint, att1 // - 2 responses are expected - let _ = env_logger::try_init(); + init_env_logger(); let peer = 98765; let mut mdns = DummyMdns {}; let matter = matter(&mut mdns); diff --git a/matter/tests/data_model/attribute_lists.rs b/matter/tests/data_model/attribute_lists.rs index ace1f3db..aaa2b635 100644 --- a/matter/tests/data_model/attribute_lists.rs +++ b/matter/tests/data_model/attribute_lists.rs @@ -29,6 +29,7 @@ use matter::{ use crate::common::{ echo_cluster::{self, TestChecker}, im_engine::{matter, ImEngine}, + init_env_logger, }; // Helper for handling Write Attribute sequences @@ -40,7 +41,7 @@ fn attr_list_ops() { let val1: u16 = 15; let tc_handle = TestChecker::get().unwrap(); - let _ = env_logger::try_init(); + init_env_logger(); let delete_item = EncodeValue::Closure(&|tag, t| { let _ = t.null(tag); diff --git a/matter/tests/data_model/attributes.rs b/matter/tests/data_model/attributes.rs index 17e41124..6d1072cd 100644 --- a/matter/tests/data_model/attributes.rs +++ b/matter/tests/data_model/attributes.rs @@ -35,6 +35,7 @@ use crate::{ attributes::*, echo_cluster, im_engine::{matter, ImEngine}, + init_env_logger, }, }; @@ -44,7 +45,7 @@ fn test_read_success() { // - first on endpoint 0, att1 // - second on endpoint 1, att2 // - third on endpoint 1, attcustom a custom attribute - let _ = env_logger::try_init(); + init_env_logger(); let ep0_att1 = GenericPath::new( Some(0), @@ -86,7 +87,7 @@ fn test_read_unsupported_fields() { // - attribute doesn't exist - UnsupportedAttribute // - attribute doesn't exist and endpoint is wildcard - Silently ignore // - attribute doesn't exist and cluster is wildcard - Silently ignore - let _ = env_logger::try_init(); + init_env_logger(); let invalid_endpoint = GenericPath::new( Some(2), @@ -129,7 +130,7 @@ fn test_read_wc_endpoint_all_have_clusters() { // 1 Attr Read Requests // - wildcard endpoint, att1 // - 2 responses are expected - let _ = env_logger::try_init(); + init_env_logger(); let wc_ep_att1 = GenericPath::new( None, @@ -160,7 +161,7 @@ fn test_read_wc_endpoint_only_1_has_cluster() { // 1 Attr Read Requests // - wildcard endpoint, on/off Cluster OnOff Attribute // - 1 response are expected - let _ = env_logger::try_init(); + init_env_logger(); let wc_ep_onoff = GenericPath::new( None, @@ -185,7 +186,7 @@ fn test_read_wc_endpoint_wc_attribute() { // 1 Attr Read Request // - wildcard endpoint, wildcard attribute // - 8 responses are expected, 1+3 attributes on endpoint 0, 1+3 on endpoint 1 - let _ = env_logger::try_init(); + init_env_logger(); let wc_ep_wc_attr = GenericPath::new(None, Some(echo_cluster::ID), None); let input = &[AttrPath::new(&wc_ep_wc_attr)]; @@ -294,7 +295,7 @@ fn test_write_success() { // - second on endpoint 1, AttWrite let val0 = 10; let val1 = 15; - let _ = env_logger::try_init(); + init_env_logger(); let attr_data0 = |tag, t: &mut TLVWriter| { let _ = t.u16(tag, val0); }; @@ -342,7 +343,7 @@ fn test_write_wc_endpoint() { // 1 Attr Write Request // - wildcard endpoint, AttWrite let val0 = 10; - let _ = env_logger::try_init(); + init_env_logger(); let attr_data0 = |tag, t: &mut TLVWriter| { let _ = t.u16(tag, val0); }; @@ -390,7 +391,7 @@ fn test_write_unsupported_fields() { // - attribute doesn't exist and endpoint is wildcard - Silently ignore // - cluster is wildcard - Cluster cannot be wildcard - UnsupportedCluster // - attribute is wildcard - Attribute cannot be wildcard - UnsupportedAttribute - let _ = env_logger::try_init(); + init_env_logger(); let val0 = 50; let attr_data0 = |tag, t: &mut TLVWriter| { diff --git a/matter/tests/data_model/commands.rs b/matter/tests/data_model/commands.rs index 50c1a8a3..a232f269 100644 --- a/matter/tests/data_model/commands.rs +++ b/matter/tests/data_model/commands.rs @@ -21,6 +21,7 @@ use crate::{ commands::*, echo_cluster, im_engine::{matter, ImEngine}, + init_env_logger, }, echo_req, echo_resp, }; @@ -39,7 +40,7 @@ fn test_invoke_cmds_success() { // 2 echo Requests // - one on endpoint 0 with data 5, // - another on endpoint 1 with data 10 - let _ = env_logger::try_init(); + init_env_logger(); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 30)]; @@ -54,7 +55,7 @@ fn test_invoke_cmds_unsupported_fields() { // - cluster doesn't exist and endpoint is wildcard - UnsupportedCluster // - command doesn't exist - UnsupportedCommand // - command doesn't exist and endpoint is wildcard - UnsupportedCommand - let _ = env_logger::try_init(); + init_env_logger(); let invalid_endpoint = CmdPath::new( Some(2), @@ -105,7 +106,7 @@ fn test_invoke_cmds_unsupported_fields() { fn test_invoke_cmd_wc_endpoint_all_have_clusters() { // 1 echo Request with wildcard endpoint // should generate 2 responses from the echo clusters on both - let _ = env_logger::try_init(); + init_env_logger(); let path = CmdPath::new( None, Some(echo_cluster::ID), @@ -120,7 +121,7 @@ fn test_invoke_cmd_wc_endpoint_all_have_clusters() { fn test_invoke_cmd_wc_endpoint_only_1_has_cluster() { // 1 on command for on/off cluster with wildcard endpoint // should generate 1 response from the on-off cluster - let _ = env_logger::try_init(); + init_env_logger(); let target = CmdPath::new( None, diff --git a/matter/tests/data_model/long_reads.rs b/matter/tests/data_model/long_reads.rs index 693f1dfa..a396cc0d 100644 --- a/matter/tests/data_model/long_reads.rs +++ b/matter/tests/data_model/long_reads.rs @@ -34,7 +34,7 @@ use matter::{ tlv::{self, ElementType, FromTLV, TLVElement, TagType, ToTLV}, transport::{ exchange::{self, Exchange}, - udp::MAX_RX_BUF_SIZE, + packet::MAX_RX_BUF_SIZE, }, Matter, }; @@ -45,6 +45,7 @@ use crate::{ attributes::*, echo_cluster as echo, im_engine::{matter, ImEngine, ImInput}, + init_env_logger, }, }; @@ -251,7 +252,7 @@ fn wildcard_read_resp(part: u8) -> Vec> { #[test] fn test_long_read_success() { // Read the entire attribute database, which requires 2 reads to complete - let _ = env_logger::try_init(); + init_env_logger(); let mut mdns = DummyMdns; let matter = matter(&mut mdns); let mut lr = LongRead::new(&matter); @@ -285,7 +286,7 @@ fn test_long_read_success() { #[test] fn test_long_read_subscription_success() { // Subscribe to the entire attribute database, which requires 2 reads to complete - let _ = env_logger::try_init(); + init_env_logger(); let mut mdns = DummyMdns; let matter = matter(&mut mdns); let mut lr = LongRead::new(&matter); diff --git a/matter/tests/data_model/timed_requests.rs b/matter/tests/data_model/timed_requests.rs index cf5ddbd7..3f441901 100644 --- a/matter/tests/data_model/timed_requests.rs +++ b/matter/tests/data_model/timed_requests.rs @@ -32,6 +32,7 @@ use crate::{ echo_cluster, handlers::{TimedInvResponse, WriteResponse}, im_engine::{matter, ImEngine}, + init_env_logger, }, echo_req, echo_resp, }; @@ -41,7 +42,7 @@ fn test_timed_write_fail_and_success() { // - 1 Timed Attr Write Transaction should fail due to timeout // - 1 Timed Attr Write Transaction should succeed let val0 = 10; - let _ = env_logger::try_init(); + init_env_logger(); let attr_data0 = |tag, t: &mut TLVWriter| { let _ = t.u16(tag, val0); }; @@ -98,7 +99,7 @@ fn test_timed_write_fail_and_success() { #[test] fn test_timed_cmd_success() { // A timed request that works - let _ = env_logger::try_init(); + init_env_logger(); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 30)]; @@ -115,7 +116,7 @@ fn test_timed_cmd_success() { #[test] fn test_timed_cmd_timeout() { // A timed request that is executed after t imeout - let _ = env_logger::try_init(); + init_env_logger(); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; ImEngine::new_with_timed_commands( @@ -131,7 +132,7 @@ fn test_timed_cmd_timeout() { #[test] fn test_timed_cmd_timedout_mismatch() { // A timed request with timeout mismatch - let _ = env_logger::try_init(); + init_env_logger(); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; ImEngine::new_with_timed_commands( diff --git a/matter/tests/interaction_model.rs b/matter/tests/interaction_model.rs index b73ab46f..5d2c21a6 100644 --- a/matter/tests/interaction_model.rs +++ b/matter/tests/interaction_model.rs @@ -24,6 +24,9 @@ use matter::interaction_model::core::Transaction; use matter::transport::exchange::Exchange; use matter::transport::exchange::ExchangeCtx; use matter::transport::network::Address; +use matter::transport::network::IpAddr; +use matter::transport::network::Ipv4Addr; +use matter::transport::network::SocketAddr; use matter::transport::packet::Packet; use matter::transport::packet::MAX_RX_BUF_SIZE; use matter::transport::packet::MAX_TX_BUF_SIZE; @@ -31,8 +34,6 @@ use matter::transport::proto_ctx::ProtoCtx; use matter::transport::session::SessionMgr; use matter::utils::epoch::dummy_epoch; use matter::utils::rand::dummy_rand; -use std::net::Ipv4Addr; -use std::net::SocketAddr; struct Node { pub endpoint: u16, @@ -95,7 +96,7 @@ fn handle_data(action: OpCode, data_in: &[u8], data_out: &mut [u8]) -> (DataMode .get_or_add( 0, Address::Udp(SocketAddr::new( - std::net::IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 5542, )), None, From 695869f13a9e03fc88aa15603cd4ac5dc3fbf09e Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Thu, 4 May 2023 05:42:58 +0000 Subject: [PATCH 37/72] Fix compilation errors in crypto --- matter/Cargo.toml | 12 +-- matter/src/cert/mod.rs | 1 - matter/src/crypto/crypto_dummy.rs | 8 +- matter/src/crypto/crypto_mbedtls.rs | 3 +- matter/src/crypto/crypto_openssl.rs | 15 +++- matter/src/crypto/crypto_rustcrypto.rs | 33 ++++--- matter/src/data_model/sdm/noc.rs | 4 +- matter/src/pairing/qr.rs | 1 + matter/src/secure_channel/case.rs | 2 +- matter/src/secure_channel/crypto.rs | 5 +- matter/src/secure_channel/crypto_dummy.rs | 7 +- matter/src/secure_channel/crypto_mbedtls.rs | 7 +- matter/src/secure_channel/crypto_openssl.rs | 9 +- .../src/secure_channel/crypto_rustcrypto.rs | 88 ++++++++++++------- matter/src/secure_channel/pake.rs | 2 +- matter/src/secure_channel/spake2p.rs | 10 ++- matter/src/transport/network.rs | 4 +- 17 files changed, 139 insertions(+), 72 deletions(-) diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 78e09932..e4a901a0 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -20,10 +20,10 @@ std = ["alloc", "env_logger", "chrono", "rand", "qrcode", "libmdns", "simple-mdn backtrace = [] alloc = [] nightly = [] -crypto_openssl = ["openssl", "foreign-types", "hmac", "sha2"] -crypto_mbedtls = ["mbedtls", "alloc"] +crypto_openssl = ["alloc", "openssl", "foreign-types", "hmac", "sha2"] +crypto_mbedtls = ["alloc", "mbedtls"] crypto_esp_mbedtls = ["esp-idf-sys"] -crypto_rustcrypto = ["sha2", "hmac", "pbkdf2", "hkdf", "aes", "ccm", "p256", "elliptic-curve", "crypto-bigint", "x509-cert"] +crypto_rustcrypto = ["alloc", "sha2", "hmac", "pbkdf2", "hkdf", "aes", "ccm", "p256", "elliptic-curve", "crypto-bigint", "x509-cert", "rand_core"] [dependencies] matter_macro_derive = { path = "../matter_macro_derive" } @@ -56,6 +56,8 @@ smol = { version = "1.3.0", optional = true} openssl = { git = "https://github.com/sfackler/rust-openssl", optional = true } mbedtls = { version = "0.9", optional = true } esp-idf-sys = { version = "0.32", optional = true } + +# rust-crypto foreign-types = { version = "0.3.2", optional = true } sha2 = { version = "0.10", default-features = false, optional = true } hmac = { version = "0.12", optional = true } @@ -66,8 +68,8 @@ ccm = { version = "0.5", default-features = false, features = ["alloc"], optiona p256 = { version = "0.13.0", default-features = false, features = ["arithmetic", "ecdh", "ecdsa"], optional = true } elliptic-curve = { version = "0.13.2", optional = true } crypto-bigint = { version = "0.4", default-features = false, optional = true } -# TODO: requires STD -x509-cert = { version = "0.2.0", default-features = false, features = ["pem", "std"], optional = true } +rand_core = { version = "0.6", default-features = false, optional = true } +x509-cert = { version = "0.2.0", default-features = false, features = ["pem", "std"], optional = true } # TODO: requires `alloc` # to compute the check digit verhoeff = "1" diff --git a/matter/src/cert/mod.rs b/matter/src/cert/mod.rs index d750db56..0ec6dd56 100644 --- a/matter/src/cert/mod.rs +++ b/matter/src/cert/mod.rs @@ -859,7 +859,6 @@ mod tests { #[test] fn test_tlv_conversions() { - let _ = env_logger::try_init(); let test_input: [&[u8]; 3] = [ &test_vectors::NOC1_SUCCESS, &test_vectors::ICAC1_SUCCESS, diff --git a/matter/src/crypto/crypto_dummy.rs b/matter/src/crypto/crypto_dummy.rs index f00cefd8..827b7f3e 100644 --- a/matter/src/crypto/crypto_dummy.rs +++ b/matter/src/crypto/crypto_dummy.rs @@ -17,7 +17,10 @@ use log::error; -use crate::error::{Error, ErrorCode}; +use crate::{ + error::{Error, ErrorCode}, + utils::rand::Rand, +}; pub fn hkdf_sha256(_salt: &[u8], _ikm: &[u8], _info: &[u8], _key: &mut [u8]) -> Result<(), Error> { error!("This API should never get called"); @@ -60,10 +63,11 @@ impl HmacSha256 { } } +#[derive(Debug)] pub struct KeyPair; impl KeyPair { - pub fn new() -> Result { + pub fn new(_rand: Rand) -> Result { Ok(Self) } diff --git a/matter/src/crypto/crypto_mbedtls.rs b/matter/src/crypto/crypto_mbedtls.rs index 3f95d04f..1eb7a884 100644 --- a/matter/src/crypto/crypto_mbedtls.rs +++ b/matter/src/crypto/crypto_mbedtls.rs @@ -35,6 +35,7 @@ use crate::{ // so Crypto doesn't have to depend on Cert cert::{ASN1Writer, CertConsumer}, error::{Error, ErrorCode}, + utils::rand::Rand, }; pub struct HmacSha256 { @@ -65,7 +66,7 @@ pub struct KeyPair { } impl KeyPair { - pub fn new() -> Result { + pub fn new(_rand: Rand) -> Result { let mut ctr_drbg = CtrDrbg::new(Arc::new(OsEntropy::new()), None)?; Ok(Self { key: Pk::generate_ec(&mut ctr_drbg, EcGroupId::SecP256R1)?, diff --git a/matter/src/crypto/crypto_openssl.rs b/matter/src/crypto/crypto_openssl.rs index 5343c528..24fa267f 100644 --- a/matter/src/crypto/crypto_openssl.rs +++ b/matter/src/crypto/crypto_openssl.rs @@ -16,7 +16,9 @@ */ use crate::error::{Error, ErrorCode}; +use crate::utils::rand::Rand; +use alloc::vec; use foreign_types::ForeignTypeRef; use log::error; use openssl::asn1::Asn1Type; @@ -39,6 +41,9 @@ use openssl::x509::{X509NameBuilder, X509ReqBuilder, X509}; // problem while using OpenSSL's Signer // TODO: Use proper OpenSSL method for this use hmac::{Hmac, Mac}; + +extern crate alloc; + pub struct HmacSha256 { ctx: Hmac, } @@ -62,16 +67,18 @@ impl HmacSha256 { } } +#[derive(Debug)] pub enum KeyType { Public(EcKey), Private(EcKey), } +#[derive(Debug)] pub struct KeyPair { key: KeyType, } impl KeyPair { - pub fn new() -> Result { + pub fn new(_rand: Rand) -> Result { let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?; let key = EcKey::generate(&group)?; Ok(Self { @@ -206,7 +213,7 @@ impl KeyPair { KeyType::Public(key) => key, _ => { error!("Not yet supported"); - Err(ErrorCode::Invalid)?; + return Err(ErrorCode::Invalid.into()); } }; if !sig.verify(&msg, k)? { @@ -293,7 +300,7 @@ pub fn lowlevel_encrypt_aead( aad: &[u8], data: &[u8], tag: &mut [u8], -) -> Result, ErrorStack> { +) -> Result, ErrorStack> { let t = symm::Cipher::aes_128_ccm(); let mut ctx = CipherCtx::new()?; CipherCtxRef::encrypt_init( @@ -329,7 +336,7 @@ pub fn lowlevel_decrypt_aead( aad: &[u8], data: &[u8], tag: &[u8], -) -> Result, ErrorStack> { +) -> Result, ErrorStack> { let t = symm::Cipher::aes_128_ccm(); let mut ctx = CipherCtx::new()?; CipherCtxRef::decrypt_init( diff --git a/matter/src/crypto/crypto_rustcrypto.rs b/matter/src/crypto/crypto_rustcrypto.rs index b9aa3101..6f975cb2 100644 --- a/matter/src/crypto/crypto_rustcrypto.rs +++ b/matter/src/crypto/crypto_rustcrypto.rs @@ -15,9 +15,10 @@ * limitations under the License. */ -use std::convert::{TryFrom, TryInto}; +use core::convert::{TryFrom, TryInto}; use aes::Aes128; +use alloc::vec; use ccm::{ aead::generic_array::GenericArray, consts::{U13, U16}, @@ -39,13 +40,17 @@ use x509_cert::{ spki::{AlgorithmIdentifier, SubjectPublicKeyInfoOwned}, }; -use crate::error::{Error, ErrorCode}; - -use super::CryptoKeyPair; +use crate::{ + error::{Error, ErrorCode}, + secure_channel::crypto_rustcrypto::RandRngCore, + utils::rand::Rand, +}; type HmacSha256I = hmac::Hmac; type AesCcm = Ccm; +extern crate alloc; + #[derive(Clone)] pub struct Sha256 { hasher: sha2::Sha256, @@ -96,18 +101,20 @@ impl HmacSha256 { } } +#[derive(Debug)] pub enum KeyType { Private(SecretKey), Public(PublicKey), } +#[derive(Debug)] pub struct KeyPair { key: KeyType, } impl KeyPair { - pub fn new() -> Result { - let mut rng = rand::thread_rng(); + pub fn new(rand: Rand) -> Result { + let mut rng = RandRngCore(rand); let secret_key = SecretKey::random(&mut rng); Ok(Self { @@ -146,10 +153,8 @@ impl KeyPair { KeyType::Public(_) => Err(ErrorCode::Crypto.into()), } } -} -impl CryptoKeyPair for KeyPair { - fn get_private_key(&self, priv_key: &mut [u8]) -> Result { + pub fn get_private_key(&self, priv_key: &mut [u8]) -> Result { match &self.key { KeyType::Private(key) => { let bytes = key.to_bytes(); @@ -161,7 +166,7 @@ impl CryptoKeyPair for KeyPair { KeyType::Public(_) => Err(ErrorCode::Crypto.into()), } } - fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { + pub fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { use p256::ecdsa::signature::Signer; let subject = RdnSequence(vec![x509_cert::name::RelativeDistinguishedName( @@ -224,14 +229,14 @@ impl CryptoKeyPair for KeyPair { Ok(a) } - fn get_public_key(&self, pub_key: &mut [u8]) -> Result { + pub fn get_public_key(&self, pub_key: &mut [u8]) -> Result { let point = self.public_key_point().to_encoded_point(false); let bytes = point.as_bytes(); let len = bytes.len(); pub_key[..len].copy_from_slice(bytes); Ok(len) } - fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result { + pub fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result { let encoded_point = EncodedPoint::from_bytes(peer_pub_key).unwrap(); let peer_pubkey = PublicKey::from_encoded_point(&encoded_point).unwrap(); let private_key = self.private_key()?; @@ -247,7 +252,7 @@ impl CryptoKeyPair for KeyPair { Ok(len) } - fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result { + pub fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result { use p256::ecdsa::signature::Signer; if signature.len() < super::EC_SIGNATURE_LEN_BYTES { @@ -266,7 +271,7 @@ impl CryptoKeyPair for KeyPair { KeyType::Public(_) => todo!(), } } - fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> { + pub fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> { use p256::ecdsa::signature::Verifier; let verifying_key = VerifyingKey::from_affine(self.public_key_point()).unwrap(); diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index 6182c0c0..f7346cc3 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -217,6 +217,7 @@ struct RemoveFabricReq { pub struct NocCluster<'a> { data_ver: Dataver, epoch: Epoch, + rand: Rand, dev_att: &'a dyn DevAttDataFetcher, fabric_mgr: &'a RefCell, acl_mgr: &'a RefCell, @@ -237,6 +238,7 @@ impl<'a> NocCluster<'a> { Self { data_ver: Dataver::new(rand), epoch, + rand, dev_att, fabric_mgr, acl_mgr, @@ -566,7 +568,7 @@ impl<'a> NocCluster<'a> { Err(ErrorCode::UnsupportedAccess)?; } - let noc_keypair = KeyPair::new()?; + let noc_keypair = KeyPair::new(self.rand)?; let mut attest_challenge = [0u8; crypto::SYMM_KEY_LEN_BYTES]; attest_challenge.copy_from_slice(transaction.session().get_att_challenge()); diff --git a/matter/src/pairing/qr.rs b/matter/src/pairing/qr.rs index 3550dfb8..bb001aea 100644 --- a/matter/src/pairing/qr.rs +++ b/matter/src/pairing/qr.rs @@ -350,6 +350,7 @@ pub fn compute_qr_code<'a>( payload_base38_representation(&qr_code_data, buf) } +#[cfg(feature = "std")] fn compute_qr_version(qr_data: &str) -> i16 { match qr_data.len() { 0..=38 => 2, diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index 18011c99..3fa9249b 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -213,7 +213,7 @@ impl<'a> Case<'a> { ); // Create an ephemeral Key Pair - let key_pair = KeyPair::new()?; + let key_pair = KeyPair::new(self.rand)?; let _ = key_pair.get_public_key(&mut case_session.our_pub_key)?; // Derive the Shared Secret diff --git a/matter/src/secure_channel/crypto.rs b/matter/src/secure_channel/crypto.rs index 45d83592..d1292eb5 100644 --- a/matter/src/secure_channel/crypto.rs +++ b/matter/src/secure_channel/crypto.rs @@ -18,7 +18,8 @@ #[cfg(not(any( feature = "crypto_openssl", feature = "crypto_mbedtls", - feature = "crypto_esp_mbedtls" + feature = "crypto_esp_mbedtls", + feature = "crypto_rustcrypto" )))] pub use super::crypto_dummy::CryptoSpake2; #[cfg(feature = "crypto_esp_mbedtls")] @@ -27,3 +28,5 @@ pub use super::crypto_esp_mbedtls::CryptoSpake2; pub use super::crypto_mbedtls::CryptoSpake2; #[cfg(feature = "crypto_openssl")] pub use super::crypto_openssl::CryptoSpake2; +#[cfg(feature = "crypto_rustcrypto")] +pub use super::crypto_rustcrypto::CryptoSpake2; diff --git a/matter/src/secure_channel/crypto_dummy.rs b/matter/src/secure_channel/crypto_dummy.rs index 3933e797..414076eb 100644 --- a/matter/src/secure_channel/crypto_dummy.rs +++ b/matter/src/secure_channel/crypto_dummy.rs @@ -15,7 +15,10 @@ * limitations under the License. */ -use crate::error::{Error, ErrorCode}; +use crate::{ + error::{Error, ErrorCode}, + utils::rand::Rand, +}; #[allow(non_snake_case)] @@ -56,7 +59,7 @@ impl CryptoSpake2 { } #[allow(non_snake_case)] - pub fn get_pB(&mut self, _pB: &mut [u8]) -> Result<(), Error> { + pub fn get_pB(&mut self, _pB: &mut [u8], _rand: Rand) -> Result<(), Error> { Err(ErrorCode::Invalid.into()) } diff --git a/matter/src/secure_channel/crypto_mbedtls.rs b/matter/src/secure_channel/crypto_mbedtls.rs index de7ea487..8ddec407 100644 --- a/matter/src/secure_channel/crypto_mbedtls.rs +++ b/matter/src/secure_channel/crypto_mbedtls.rs @@ -18,7 +18,10 @@ use alloc::sync::Arc; use core::ops::{Mul, Sub}; -use crate::error::{Error, ErrorCode}; +use crate::{ + error::{Error, ErrorCode}, + utils::rand::Rand, +}; use byteorder::{ByteOrder, LittleEndian}; use log::error; @@ -132,7 +135,7 @@ impl CryptoSpake2 { } #[allow(non_snake_case)] - pub fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { + pub fn get_pB(&mut self, pB: &mut [u8], _rand: Rand) -> Result<(), Error> { // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // for y // - select random y between 0 to p diff --git a/matter/src/secure_channel/crypto_openssl.rs b/matter/src/secure_channel/crypto_openssl.rs index de60fff2..dd4f8b1d 100644 --- a/matter/src/secure_channel/crypto_openssl.rs +++ b/matter/src/secure_channel/crypto_openssl.rs @@ -15,7 +15,10 @@ * limitations under the License. */ -use crate::error::{Error, ErrorCode}; +use crate::{ + error::{Error, ErrorCode}, + utils::rand::Rand, +}; use byteorder::{ByteOrder, LittleEndian}; use log::error; @@ -116,6 +119,7 @@ impl CryptoSpake2 { Ok(()) } + #[allow(non_snake_case)] pub fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { self.L = EcPoint::from_bytes(&self.group, l, &mut self.bn_ctx)?; Ok(()) @@ -134,7 +138,7 @@ impl CryptoSpake2 { } #[allow(non_snake_case)] - pub fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { + pub fn get_pB(&mut self, pB: &mut [u8], _rand: Rand) -> Result<(), Error> { // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // for y // - select random y between 0 to p @@ -331,7 +335,6 @@ impl CryptoSpake2 { mod tests { use super::CryptoSpake2; - use crate::secure_channel::crypto::CryptoSpake2; use crate::secure_channel::spake2p_test_vectors::test_vectors::*; use openssl::bn::BigNum; use openssl::ec::{EcPoint, PointConversionForm}; diff --git a/matter/src/secure_channel/crypto_rustcrypto.rs b/matter/src/secure_channel/crypto_rustcrypto.rs index a3ec6289..2c1e50a9 100644 --- a/matter/src/secure_channel/crypto_rustcrypto.rs +++ b/matter/src/secure_channel/crypto_rustcrypto.rs @@ -21,11 +21,12 @@ use elliptic_curve::ops::*; use elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint}; use elliptic_curve::Field; use elliptic_curve::PrimeField; +use rand_core::CryptoRng; +use rand_core::RngCore; use sha2::Digest; use crate::error::Error; - -use super::crypto::CryptoSpake2; +use crate::utils::rand::Rand; const MATTER_M_BIN: [u8; 65] = [ 0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2, @@ -44,7 +45,7 @@ const MATTER_N_BIN: [u8; 65] = [ #[allow(non_snake_case)] -pub struct CryptoRustCrypto { +pub struct CryptoSpake2 { xy: p256::Scalar, w0: p256::Scalar, w1: p256::Scalar, @@ -54,15 +55,15 @@ pub struct CryptoRustCrypto { pB: p256::EncodedPoint, } -impl CryptoSpake2 for CryptoRustCrypto { +impl CryptoSpake2 { #[allow(non_snake_case)] - fn new() -> Result { + pub fn new() -> Result { let M = p256::EncodedPoint::from_bytes(MATTER_M_BIN).unwrap(); let N = p256::EncodedPoint::from_bytes(MATTER_N_BIN).unwrap(); let L = p256::EncodedPoint::default(); let pB = p256::EncodedPoint::default(); - Ok(CryptoRustCrypto { + Ok(Self { xy: p256::Scalar::ZERO, w0: p256::Scalar::ZERO, w1: p256::Scalar::ZERO, @@ -74,7 +75,7 @@ impl CryptoSpake2 for CryptoRustCrypto { } // Computes w0 from w0s respectively - fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { + pub fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { // From the Matter Spec, // w0 = w0s mod p // where p is the order of the curve @@ -103,7 +104,7 @@ impl CryptoSpake2 for CryptoRustCrypto { Ok(()) } - fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { + pub fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { // From the Matter Spec, // w1 = w1s mod p // where p is the order of the curve @@ -132,14 +133,14 @@ impl CryptoSpake2 for CryptoRustCrypto { Ok(()) } - fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { + pub fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { self.w0 = p256::Scalar::from_repr(*elliptic_curve::generic_array::GenericArray::from_slice(w0)) .unwrap(); Ok(()) } - fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { + pub fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { self.w1 = p256::Scalar::from_repr(*elliptic_curve::generic_array::GenericArray::from_slice(w1)) .unwrap(); @@ -148,12 +149,13 @@ impl CryptoSpake2 for CryptoRustCrypto { #[allow(non_snake_case)] #[allow(dead_code)] - fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { + pub fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { self.L = p256::EncodedPoint::from_bytes(l).unwrap(); Ok(()) } - fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { + #[allow(non_snake_case)] + pub fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { // From the Matter spec, // L = w1 * P // where P is the generator of the underlying elliptic curve @@ -163,14 +165,14 @@ impl CryptoSpake2 for CryptoRustCrypto { } #[allow(non_snake_case)] - fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { + pub fn get_pB(&mut self, pB: &mut [u8], rand: Rand) -> Result<(), Error> { // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // for y // - select random y between 0 to p // - Y = y*P + w0*N // - pB = Y - let mut rng = rand::thread_rng(); - self.xy = p256::Scalar::random(&mut rng); + let mut rand = RandRngCore(rand); + self.xy = p256::Scalar::random(&mut rand); let P = p256::AffinePoint::GENERATOR; let N = p256::AffinePoint::from_encoded_point(&self.N).unwrap(); @@ -182,7 +184,7 @@ impl CryptoSpake2 for CryptoRustCrypto { } #[allow(non_snake_case)] - fn get_TT_as_verifier( + pub fn get_TT_as_verifier( &mut self, context: &[u8], pA: &[u8], @@ -222,9 +224,7 @@ impl CryptoSpake2 for CryptoRustCrypto { Ok(()) } -} -impl CryptoRustCrypto { fn add_to_tt(tt: &mut sha2::Sha256, buf: &[u8]) -> Result<(), Error> { tt.update((buf.len() as u64).to_le_bytes()); if !buf.is_empty() { @@ -266,11 +266,11 @@ impl CryptoRustCrypto { let mut tmp = x * w0; let N_neg = N.neg(); - let Z = CryptoRustCrypto::do_add_mul(Y, x, N_neg, tmp)?; + let Z = Self::do_add_mul(Y, x, N_neg, tmp)?; // Cofactor for P256 is 1, so that is a No-Op tmp = w1 * w0; - let V = CryptoRustCrypto::do_add_mul(Y, w1, N_neg, tmp)?; + let V = Self::do_add_mul(Y, w1, N_neg, tmp)?; Ok((Z, V)) } @@ -297,27 +297,55 @@ impl CryptoRustCrypto { let tmp = y * w0; let M_neg = M.neg(); - let Z = CryptoRustCrypto::do_add_mul(X, y, M_neg, tmp)?; + let Z = Self::do_add_mul(X, y, M_neg, tmp)?; // Cofactor for P256 is 1, so that is a No-Op let V = (L * y).to_encoded_point(false); Ok((Z, V)) } } +pub struct RandRngCore(pub Rand); + +impl RngCore for RandRngCore { + fn next_u32(&mut self) -> u32 { + let mut buf = [0; 4]; + self.fill_bytes(&mut buf); + + u32::from_be_bytes(buf) + } + + fn next_u64(&mut self) -> u64 { + let mut buf = [0; 8]; + self.fill_bytes(&mut buf); + + u64::from_be_bytes(buf) + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + (self.0)(dest); + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + self.fill_bytes(dest); + Ok(()) + } +} + +impl CryptoRng for RandRngCore {} + #[cfg(test)] mod tests { use super::*; use elliptic_curve::sec1::FromEncodedPoint; - use crate::secure_channel::crypto::CryptoSpake2; use crate::secure_channel::spake2p_test_vectors::test_vectors::*; #[test] #[allow(non_snake_case)] fn test_get_X() { for t in RFC_T { - let mut c = CryptoRustCrypto::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let x = p256::Scalar::from_repr( *elliptic_curve::generic_array::GenericArray::from_slice(&t.x), ) @@ -325,7 +353,7 @@ mod tests { c.set_w0(&t.w0).unwrap(); let P = p256::AffinePoint::GENERATOR; let M = p256::AffinePoint::from_encoded_point(&c.M).unwrap(); - let r: p256::EncodedPoint = CryptoRustCrypto::do_add_mul(P, x, M, c.w0).unwrap(); + let r: p256::EncodedPoint = CryptoSpake2::do_add_mul(P, x, M, c.w0).unwrap(); assert_eq!(&t.X, r.as_bytes()); } } @@ -334,7 +362,7 @@ mod tests { #[allow(non_snake_case)] fn test_get_Y() { for t in RFC_T { - let mut c = CryptoRustCrypto::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let y = p256::Scalar::from_repr( *elliptic_curve::generic_array::GenericArray::from_slice(&t.y), ) @@ -342,7 +370,7 @@ mod tests { c.set_w0(&t.w0).unwrap(); let P = p256::AffinePoint::GENERATOR; let N = p256::AffinePoint::from_encoded_point(&c.N).unwrap(); - let r = CryptoRustCrypto::do_add_mul(P, y, N, c.w0).unwrap(); + let r = CryptoSpake2::do_add_mul(P, y, N, c.w0).unwrap(); assert_eq!(&t.Y, r.as_bytes()); } } @@ -351,7 +379,7 @@ mod tests { #[allow(non_snake_case)] fn test_get_ZV_as_prover() { for t in RFC_T { - let mut c = CryptoRustCrypto::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let x = p256::Scalar::from_repr( *elliptic_curve::generic_array::GenericArray::from_slice(&t.x), ) @@ -361,7 +389,7 @@ mod tests { let Y = p256::EncodedPoint::from_bytes(t.Y).unwrap(); let Y = p256::AffinePoint::from_encoded_point(&Y).unwrap(); let N = p256::AffinePoint::from_encoded_point(&c.N).unwrap(); - let (Z, V) = CryptoRustCrypto::get_ZV_as_prover(c.w0, c.w1, N, Y, x).unwrap(); + let (Z, V) = CryptoSpake2::get_ZV_as_prover(c.w0, c.w1, N, Y, x).unwrap(); assert_eq!(&t.Z, Z.as_bytes()); assert_eq!(&t.V, V.as_bytes()); @@ -372,7 +400,7 @@ mod tests { #[allow(non_snake_case)] fn test_get_ZV_as_verifier() { for t in RFC_T { - let mut c = CryptoRustCrypto::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let y = p256::Scalar::from_repr( *elliptic_curve::generic_array::GenericArray::from_slice(&t.y), ) @@ -383,7 +411,7 @@ mod tests { let L = p256::EncodedPoint::from_bytes(t.L).unwrap(); let L = p256::AffinePoint::from_encoded_point(&L).unwrap(); let M = p256::AffinePoint::from_encoded_point(&c.M).unwrap(); - let (Z, V) = CryptoRustCrypto::get_ZV_as_verifier(c.w0, L, M, X, y).unwrap(); + let (Z, V) = CryptoSpake2::get_ZV_as_verifier(c.w0, L, M, X, y).unwrap(); assert_eq!(&t.Z, Z.as_bytes()); assert_eq!(&t.V, V.as_bytes()); diff --git a/matter/src/secure_channel/pake.rs b/matter/src/secure_channel/pake.rs index 84c5ba09..cd0ffaf7 100644 --- a/matter/src/secure_channel/pake.rs +++ b/matter/src/secure_channel/pake.rs @@ -282,7 +282,7 @@ impl Pake { let mut pB: [u8; 65] = [0; 65]; let mut cB: [u8; 32] = [0; 32]; sd.spake2p.start_verifier(&self.verifier)?; - sd.spake2p.handle_pA(pA, &mut pB, &mut cB)?; + sd.spake2p.handle_pA(pA, &mut pB, &mut cB, self.rand)?; let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?); let resp = Pake1Resp { diff --git a/matter/src/secure_channel/spake2p.rs b/matter/src/secure_channel/spake2p.rs index 9be2d4de..1ee00b6b 100644 --- a/matter/src/secure_channel/spake2p.rs +++ b/matter/src/secure_channel/spake2p.rs @@ -196,13 +196,19 @@ impl Spake2P { } #[allow(non_snake_case)] - pub fn handle_pA(&mut self, pA: &[u8], pB: &mut [u8], cB: &mut [u8]) -> Result<(), Error> { + pub fn handle_pA( + &mut self, + pA: &[u8], + pB: &mut [u8], + cB: &mut [u8], + rand: Rand, + ) -> Result<(), Error> { if self.mode != Spake2Mode::Verifier(Spake2VerifierState::Init) { Err(ErrorCode::InvalidState)?; } if let Some(crypto_spake2) = &mut self.crypto_spake2 { - crypto_spake2.get_pB(pB)?; + crypto_spake2.get_pB(pB, rand)?; if let Some(context) = self.context.take() { let mut hash = [0u8; crypto::SHA256_HASH_LEN_BYTES]; context.finish(&mut hash)?; diff --git a/matter/src/transport/network.rs b/matter/src/transport/network.rs index 6cda9bcd..e03658b9 100644 --- a/matter/src/transport/network.rs +++ b/matter/src/transport/network.rs @@ -17,9 +17,9 @@ use core::fmt::{Debug, Display}; #[cfg(not(feature = "std"))] -use no_std_net::{IpAddr, Ipv4Addr, SocketAddr}; +pub use no_std_net::{IpAddr, Ipv4Addr, SocketAddr}; #[cfg(feature = "std")] -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +pub use std::net::{IpAddr, Ipv4Addr, SocketAddr}; #[derive(PartialEq, Copy, Clone)] pub enum Address { From e741cab89dfd89ae9be9813e0f03f279eeca9019 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Thu, 4 May 2023 06:13:36 +0000 Subject: [PATCH 38/72] More crypto fixes --- matter/Cargo.toml | 2 +- matter/src/crypto/crypto_rustcrypto.rs | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/matter/Cargo.toml b/matter/Cargo.toml index e4a901a0..d0ce250c 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -69,7 +69,7 @@ p256 = { version = "0.13.0", default-features = false, features = ["arithmetic", elliptic-curve = { version = "0.13.2", optional = true } crypto-bigint = { version = "0.4", default-features = false, optional = true } rand_core = { version = "0.6", default-features = false, optional = true } -x509-cert = { version = "0.2.0", default-features = false, features = ["pem", "std"], optional = true } # TODO: requires `alloc` +x509-cert = { version = "0.2.0", default-features = false, features = ["pem"], optional = true } # TODO: requires `alloc` # to compute the check digit verhoeff = "1" diff --git a/matter/src/crypto/crypto_rustcrypto.rs b/matter/src/crypto/crypto_rustcrypto.rs index 6f975cb2..6212c96c 100644 --- a/matter/src/crypto/crypto_rustcrypto.rs +++ b/matter/src/crypto/crypto_rustcrypto.rs @@ -34,7 +34,7 @@ use p256::{ use sha2::Digest; use x509_cert::{ attr::AttributeType, - der::{asn1::BitString, Any, Encode}, + der::{asn1::BitString, Any, Encode, Writer}, name::RdnSequence, request::CertReq, spki::{AlgorithmIdentifier, SubjectPublicKeyInfoOwned}, @@ -205,7 +205,7 @@ impl KeyPair { attributes: Default::default(), }; let mut message = vec![]; - info.encode(&mut message).unwrap(); + info.encode(&mut VecWriter(&mut message)).unwrap(); // Can't use self.sign_msg as the signature has to be in DER format let private_key = self.private_key()?; @@ -375,3 +375,13 @@ impl<'a> ccm::aead::Buffer for SliceBuffer<'a> { self.len = len; } } + +struct VecWriter<'a>(&'a mut alloc::vec::Vec); + +impl<'a> Writer for VecWriter<'a> { + fn write(&mut self, slice: &[u8]) -> x509_cert::der::Result<()> { + self.0.extend_from_slice(slice); + + Ok(()) + } +} From 9d59c79674cab7e5d0c86310e18cdbf5e31953d2 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Thu, 4 May 2023 07:09:59 +0000 Subject: [PATCH 39/72] Colorizing is now no_std compatible --- matter/Cargo.toml | 2 +- matter/src/interaction_model/core.rs | 4 ++-- matter/src/transport/exchange.rs | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/matter/Cargo.toml b/matter/Cargo.toml index d0ce250c..3ccf3c4e 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -38,7 +38,7 @@ log = { version = "0.4.17", features = ["max_level_debug", "release_max_level_de no-std-net = "0.6" subtle = "2.4.1" safemem = "0.3.3" -colored = "2.0.0" # TODO: Requires STD +owo-colors = "3" # STD-only dependencies env_logger = { version = "0.10.0", default-features = false, optional = true } diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index 9e29bac4..82d2eb4e 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -29,10 +29,10 @@ use crate::{ session::Session, }, }; -use colored::Colorize; use log::{error, info}; use num; use num_derive::FromPrimitive; +use owo_colors::OwoColorize; use super::messages::{ ib::{AttrPath, DataVersionFilter}, @@ -43,7 +43,7 @@ use super::messages::{ #[macro_export] macro_rules! cmd_enter { ($e:expr) => {{ - use colored::Colorize; + use owo_colors::OwoColorize; info! {"{} {}", "Handling Command".cyan(), $e.cyan()} }}; } diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index 5a9bbcf7..57f666c1 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -15,11 +15,11 @@ * limitations under the License. */ -use colored::*; use core::any::Any; use core::fmt; use core::time::Duration; use log::{error, info, trace}; +use owo_colors::OwoColorize; use crate::error::{Error, ErrorCode}; use crate::interaction_model::core::{ResumeReadReq, ResumeSubscribeReq}; From a4b8b530143e6fc97426338bcea644dd7a4fd74d Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Thu, 4 May 2023 09:35:47 +0000 Subject: [PATCH 40/72] Builds for STD with ESP IDF --- .gitignore | 1 + Cargo.toml | 8 ++++++ matter/Cargo.toml | 25 +++++++++++-------- matter/src/crypto/crypto_esp_mbedtls.rs | 4 ++- matter/src/crypto/mod.rs | 10 +++----- matter/src/error.rs | 10 +++++++- matter/src/mdns.rs | 2 +- matter/src/secure_channel/crypto.rs | 5 ++-- .../src/secure_channel/crypto_esp_mbedtls.rs | 12 ++++++++- matter/src/secure_channel/mod.rs | 5 ++-- matter/tests/common/mod.rs | 2 +- 11 files changed, 56 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index c8e9e486..636e67c7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ target Cargo.lock .vscode +.embuild diff --git a/Cargo.toml b/Cargo.toml index 268671d4..f785efae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,3 +2,11 @@ members = ["matter", "matter_macro_derive", "tools/tlv_tool"] exclude = ["examples/*"] + +# For compatibility with ESP IDF +[patch.crates-io] +smol = { git = "https://github.com/esp-rs-compat/smol" } +polling = { git = "https://github.com/esp-rs-compat/polling" } +socket2 = { git = "https://github.com/esp-rs-compat/socket2" } +chrono = { git = "https://github.com/ivmarkov/chrono" } +time = { git = "https://github.com/ivmarkov/time", branch = "master" } diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 3ccf3c4e..a1f08401 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -16,13 +16,12 @@ path = "src/lib.rs" [features] default = ["std", "crypto_mbedtls", "backtrace"] -std = ["alloc", "env_logger", "chrono", "rand", "qrcode", "libmdns", "simple-mdns", "simple-dns", "smol"] +std = ["alloc", "env_logger", "chrono", "time", "rand", "qrcode", "libmdns", "simple-mdns", "simple-dns", "async-io", "smol"] backtrace = [] alloc = [] nightly = [] crypto_openssl = ["alloc", "openssl", "foreign-types", "hmac", "sha2"] -crypto_mbedtls = ["alloc", "mbedtls"] -crypto_esp_mbedtls = ["esp-idf-sys"] +crypto_mbedtls = ["alloc", "mbedtls", "esp-idf-sys"] crypto_rustcrypto = ["alloc", "sha2", "hmac", "pbkdf2", "hkdf", "aes", "ccm", "p256", "elliptic-curve", "crypto-bigint", "x509-cert", "rand_core"] [dependencies] @@ -39,23 +38,22 @@ no-std-net = "0.6" subtle = "2.4.1" safemem = "0.3.3" owo-colors = "3" +verhoeff = { version = "1", default-features = false } # STD-only dependencies -env_logger = { version = "0.10.0", default-features = false, optional = true } -chrono = { version = "0.4.23", optional = true, default-features = false, features = ["clock", "std"] } +chrono = { version = "=0.4.19", optional = true, default-features = false, features = ["clock", "std"] } # =0.4.19 for compatibility with ESP IDF +time = { version = "0.1", optional = true, default-features = false } rand = { version = "0.8.5", optional = true } qrcode = { version = "0.12", default-features = false, optional = true } # Print QR code -libmdns = { version = "0.7", optional = true } simple-mdns = { version = "0.4", features = ["sync"], optional = true } simple-dns = { version = "0.5", optional = true } astro-dnssd = { version = "0.3", optional = true } # On Linux needs avahi-compat-libdns_sd, i.e. on Ubuntu/Debian do `sudo apt-get install libavahi-compat-libdnssd-dev` zeroconf = { version = "0.10", optional = true } -smol = { version = "1.3.0", optional = true} +smol = { version = "1.2", optional = true } # =1.2 for compatibility with ESP IDF +async-io = { version = "=1.12", optional = true } # =1.2 for compatibility with ESP IDF # crypto openssl = { git = "https://github.com/sfackler/rust-openssl", optional = true } -mbedtls = { version = "0.9", optional = true } -esp-idf-sys = { version = "0.32", optional = true } # rust-crypto foreign-types = { version = "0.3.2", optional = true } @@ -71,8 +69,13 @@ crypto-bigint = { version = "0.4", default-features = false, optional = true } rand_core = { version = "0.6", default-features = false, optional = true } x509-cert = { version = "0.2.0", default-features = false, features = ["pem"], optional = true } # TODO: requires `alloc` -# to compute the check digit -verhoeff = "1" +[target.'cfg(not(target_os = "espidf"))'.dependencies] +mbedtls = { git = "https://github.com/fortanix/rust-mbedtls", optional = true } +env_logger = { version = "0.10.0", optional = true } +libmdns = { version = "0.7", optional = true } + +[target.'cfg(target_os = "espidf")'.dependencies] +esp-idf-sys = { version = "0.32", default-features = false, features = ["native"], optional = true } [[example]] name = "onoff_light" diff --git a/matter/src/crypto/crypto_esp_mbedtls.rs b/matter/src/crypto/crypto_esp_mbedtls.rs index cad046ba..2fff707b 100644 --- a/matter/src/crypto/crypto_esp_mbedtls.rs +++ b/matter/src/crypto/crypto_esp_mbedtls.rs @@ -18,6 +18,7 @@ use log::error; use crate::error::{Error, ErrorCode}; +use crate::utils::rand::Rand; pub fn hkdf_sha256(_salt: &[u8], _ikm: &[u8], _info: &[u8], _key: &mut [u8]) -> Result<(), Error> { error!("This API should never get called"); @@ -60,10 +61,11 @@ impl HmacSha256 { } } +#[derive(Debug)] pub struct KeyPair {} impl KeyPair { - pub fn new() -> Result { + pub fn new(_rand: Rand) -> Result { error!("This API should never get called"); Ok(Self {}) diff --git a/matter/src/crypto/mod.rs b/matter/src/crypto/mod.rs index 47d49b72..85c40b07 100644 --- a/matter/src/crypto/mod.rs +++ b/matter/src/crypto/mod.rs @@ -37,14 +37,14 @@ pub const ECDH_SHARED_SECRET_LEN_BYTES: usize = 32; pub const EC_SIGNATURE_LEN_BYTES: usize = 64; -#[cfg(feature = "crypto_esp_mbedtls")] +#[cfg(all(feature = "crypto_mbedtls", target_os = "espidf"))] mod crypto_esp_mbedtls; -#[cfg(feature = "crypto_esp_mbedtls")] +#[cfg(all(feature = "crypto_mbedtls", target_os = "espidf"))] pub use self::crypto_esp_mbedtls::*; -#[cfg(feature = "crypto_mbedtls")] +#[cfg(all(feature = "crypto_mbedtls", not(target_os = "espidf")))] mod crypto_mbedtls; -#[cfg(feature = "crypto_mbedtls")] +#[cfg(all(feature = "crypto_mbedtls", not(target_os = "espidf")))] pub use self::crypto_mbedtls::*; #[cfg(feature = "crypto_openssl")] @@ -60,14 +60,12 @@ pub use self::crypto_rustcrypto::*; #[cfg(not(any( feature = "crypto_openssl", feature = "crypto_mbedtls", - feature = "crypto_esp_mbedtls", feature = "crypto_rustcrypto" )))] pub mod crypto_dummy; #[cfg(not(any( feature = "crypto_openssl", feature = "crypto_mbedtls", - feature = "crypto_esp_mbedtls", feature = "crypto_rustcrypto" )))] pub use self::crypto_dummy::*; diff --git a/matter/src/error.rs b/matter/src/error.rs index 507ce4bf..e15cbb71 100644 --- a/matter/src/error.rs +++ b/matter/src/error.rs @@ -157,7 +157,7 @@ impl From for Error { } } -#[cfg(feature = "crypto_mbedtls")] +#[cfg(all(feature = "crypto_mbedtls", not(target_os = "espidf")))] impl From for Error { fn from(e: mbedtls::Error) -> Self { ::log::error!("Error in TLS: {}", e); @@ -165,6 +165,14 @@ impl From for Error { } } +#[cfg(all(feature = "crypto_mbedtls", target_os = "espidf"))] +impl From for Error { + fn from(e: esp_idf_sys::EspError) -> Self { + ::log::error!("Error in TLS: {}", e); + Self::new(ErrorCode::TLSStack) + } +} + #[cfg(feature = "crypto_rustcrypto")] impl From for Error { fn from(_e: ccm::aead::Error) -> Self { diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index eb19a9f1..defb1374 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -475,7 +475,7 @@ pub mod astro { // } // } -#[cfg(feature = "std")] +#[cfg(all(feature = "std", not(target_os = "espidf")))] pub mod libmdns { use super::Mdns; use crate::error::Error; diff --git a/matter/src/secure_channel/crypto.rs b/matter/src/secure_channel/crypto.rs index d1292eb5..027db690 100644 --- a/matter/src/secure_channel/crypto.rs +++ b/matter/src/secure_channel/crypto.rs @@ -18,13 +18,12 @@ #[cfg(not(any( feature = "crypto_openssl", feature = "crypto_mbedtls", - feature = "crypto_esp_mbedtls", feature = "crypto_rustcrypto" )))] pub use super::crypto_dummy::CryptoSpake2; -#[cfg(feature = "crypto_esp_mbedtls")] +#[cfg(all(feature = "crypto_mbedtls", target_os = "espidf"))] pub use super::crypto_esp_mbedtls::CryptoSpake2; -#[cfg(feature = "crypto_mbedtls")] +#[cfg(all(feature = "crypto_mbedtls", not(target_os = "espidf")))] pub use super::crypto_mbedtls::CryptoSpake2; #[cfg(feature = "crypto_openssl")] pub use super::crypto_openssl::CryptoSpake2; diff --git a/matter/src/secure_channel/crypto_esp_mbedtls.rs b/matter/src/secure_channel/crypto_esp_mbedtls.rs index 316276ba..d1d77eeb 100644 --- a/matter/src/secure_channel/crypto_esp_mbedtls.rs +++ b/matter/src/secure_channel/crypto_esp_mbedtls.rs @@ -16,6 +16,7 @@ */ use crate::error::Error; +use crate::utils::rand::Rand; const MATTER_M_BIN: [u8; 65] = [ 0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2, @@ -77,7 +78,16 @@ impl CryptoSpake2 { } #[allow(non_snake_case)] - pub fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { + #[allow(dead_code)] + pub fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { + // From the Matter spec, + // L = w1 * P + // where P is the generator of the underlying elliptic curve + Ok(()) + } + + #[allow(non_snake_case)] + pub fn get_pB(&mut self, pB: &mut [u8], _rand: Rand) -> Result<(), Error> { // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // for y // - select random y between 0 to p diff --git a/matter/src/secure_channel/mod.rs b/matter/src/secure_channel/mod.rs index 15417b3b..58020b44 100644 --- a/matter/src/secure_channel/mod.rs +++ b/matter/src/secure_channel/mod.rs @@ -20,13 +20,12 @@ pub mod common; #[cfg(not(any( feature = "crypto_openssl", feature = "crypto_mbedtls", - feature = "crypto_esp_mbedtls", feature = "crypto_rustcrypto" )))] mod crypto_dummy; -#[cfg(feature = "crypto_esp_mbedtls")] +#[cfg(all(feature = "crypto_mbedtls", target_os = "espidf"))] mod crypto_esp_mbedtls; -#[cfg(feature = "crypto_mbedtls")] +#[cfg(all(feature = "crypto_mbedtls", not(target_os = "espidf")))] mod crypto_mbedtls; #[cfg(feature = "crypto_openssl")] pub mod crypto_openssl; diff --git a/matter/tests/common/mod.rs b/matter/tests/common/mod.rs index 96682b98..94837fc1 100644 --- a/matter/tests/common/mod.rs +++ b/matter/tests/common/mod.rs @@ -22,7 +22,7 @@ pub mod handlers; pub mod im_engine; pub fn init_env_logger() { - #[cfg(feature = "std")] + #[cfg(all(feature = "std", not(target_os = "espidf")))] { let _ = env_logger::try_init(); } From 592d1ee028f667c61e89b16990e1c25ceea32e8a Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Thu, 4 May 2023 10:57:21 +0000 Subject: [PATCH 41/72] Just use time-rs in no_std mode --- Cargo.toml | 2 -- matter/Cargo.toml | 5 ++- matter/src/cert/asn1_writer.rs | 31 +++++++++------- matter/src/cert/mod.rs | 60 +++++++++++-------------------- matter/src/cert/printer.rs | 13 ++++--- matter/src/core.rs | 18 ++-------- matter/src/secure_channel/case.rs | 24 ++++--------- matter/src/secure_channel/core.rs | 5 ++- matter/src/transport/mgr.rs | 6 ++-- matter/src/utils/epoch.rs | 45 ----------------------- matter/tests/common/im_engine.rs | 6 +--- tools/tlv_tool/src/main.rs | 3 +- 12 files changed, 63 insertions(+), 155 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f785efae..2e964561 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,5 +8,3 @@ exclude = ["examples/*"] smol = { git = "https://github.com/esp-rs-compat/smol" } polling = { git = "https://github.com/esp-rs-compat/polling" } socket2 = { git = "https://github.com/esp-rs-compat/socket2" } -chrono = { git = "https://github.com/ivmarkov/chrono" } -time = { git = "https://github.com/ivmarkov/time", branch = "master" } diff --git a/matter/Cargo.toml b/matter/Cargo.toml index a1f08401..1e257574 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -16,7 +16,7 @@ path = "src/lib.rs" [features] default = ["std", "crypto_mbedtls", "backtrace"] -std = ["alloc", "env_logger", "chrono", "time", "rand", "qrcode", "libmdns", "simple-mdns", "simple-dns", "async-io", "smol"] +std = ["alloc", "env_logger", "rand", "qrcode", "libmdns", "simple-mdns", "simple-dns", "async-io", "smol"] backtrace = [] alloc = [] nightly = [] @@ -38,11 +38,10 @@ no-std-net = "0.6" subtle = "2.4.1" safemem = "0.3.3" owo-colors = "3" +time = { version = "0.3", default-features = false } verhoeff = { version = "1", default-features = false } # STD-only dependencies -chrono = { version = "=0.4.19", optional = true, default-features = false, features = ["clock", "std"] } # =0.4.19 for compatibility with ESP IDF -time = { version = "0.1", optional = true, default-features = false } rand = { version = "0.8.5", optional = true } qrcode = { version = "0.12", default-features = false, optional = true } # Print QR code simple-mdns = { version = "0.4", features = ["sync"], optional = true } diff --git a/matter/src/cert/asn1_writer.rs b/matter/src/cert/asn1_writer.rs index 4afd6b6c..87fac3c0 100644 --- a/matter/src/cert/asn1_writer.rs +++ b/matter/src/cert/asn1_writer.rs @@ -15,12 +15,14 @@ * limitations under the License. */ +use time::OffsetDateTime; + use super::{CertConsumer, MAX_DEPTH}; use crate::{ error::{Error, ErrorCode}, - utils::epoch::{UtcCalendar, MATTER_EPOCH_SECS}, + utils::epoch::MATTER_EPOCH_SECS, }; -use core::{fmt::Write, time::Duration}; +use core::fmt::Write; #[derive(Debug)] pub struct ASN1Writer<'a> { @@ -262,19 +264,24 @@ impl<'a> CertConsumer for ASN1Writer<'a> { self.write_str(0x06, oid) } - fn utctime(&mut self, _tag: &str, epoch: u32, utc_calendar: UtcCalendar) -> Result<(), Error> { + fn utctime(&mut self, _tag: &str, epoch: u32) -> Result<(), Error> { let matter_epoch = MATTER_EPOCH_SECS + epoch as u64; - let dt = utc_calendar(Duration::from_secs(matter_epoch as _)); + let dt = OffsetDateTime::from_unix_timestamp(matter_epoch as _).unwrap(); let mut time_str: heapless::String<32> = heapless::String::<32>::new(); - if dt.year >= 2050 { + if dt.year() >= 2050 { // If year is >= 2050, ASN.1 requires it to be Generalised Time write!( &mut time_str, "{:04}{:02}{:02}{:02}{:02}{:02}Z", - dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second + dt.year(), + dt.month() as u8, + dt.day(), + dt.hour(), + dt.minute(), + dt.second() ) .unwrap(); self.write_str(0x18, time_str.as_bytes()) @@ -282,12 +289,12 @@ impl<'a> CertConsumer for ASN1Writer<'a> { write!( &mut time_str, "{:02}{:02}{:02}{:02}{:02}{:02}Z", - dt.year % 100, - dt.month, - dt.day, - dt.hour, - dt.minute, - dt.second + dt.year() % 100, + dt.month() as u8, + dt.day(), + dt.hour(), + dt.minute(), + dt.second() ) .unwrap(); self.write_str(0x17, time_str.as_bytes()) diff --git a/matter/src/cert/mod.rs b/matter/src/cert/mod.rs index 0ec6dd56..8878622c 100644 --- a/matter/src/cert/mod.rs +++ b/matter/src/cert/mod.rs @@ -21,7 +21,7 @@ use crate::{ crypto::KeyPair, error::{Error, ErrorCode}, tlv::{self, FromTLV, OctetStr, TLVArray, TLVElement, TLVWriter, TagType, ToTLV}, - utils::{epoch::UtcCalendar, writebuf::WriteBuf}, + utils::writebuf::WriteBuf, }; use log::error; use num_derive::FromPrimitive; @@ -621,21 +621,17 @@ impl<'a> Cert<'a> { Ok(wb.as_slice().len()) } - pub fn as_asn1(&self, buf: &mut [u8], utc_calendar: UtcCalendar) -> Result { + pub fn as_asn1(&self, buf: &mut [u8]) -> Result { let mut w = ASN1Writer::new(buf); - self.encode(&mut w, Some(utc_calendar))?; + self.encode(&mut w)?; Ok(w.as_slice().len()) } - pub fn verify_chain_start(&self, utc_calendar: UtcCalendar) -> CertVerifier { - CertVerifier::new(self, utc_calendar) + pub fn verify_chain_start(&self) -> CertVerifier { + CertVerifier::new(self) } - fn encode( - &self, - w: &mut dyn CertConsumer, - utc_calendar: Option, - ) -> Result<(), Error> { + fn encode(&self, w: &mut dyn CertConsumer) -> Result<(), Error> { w.start_seq("")?; w.start_ctx("Version:", 0)?; @@ -654,10 +650,8 @@ impl<'a> Cert<'a> { self.issuer.encode("Issuer:", w)?; w.start_seq("Validity:")?; - if let Some(utc_calendar) = utc_calendar { - w.utctime("Not Before:", self.not_before, utc_calendar)?; - w.utctime("Not After:", self.not_after, utc_calendar)?; - } + w.utctime("Not Before:", self.not_before)?; + w.utctime("Not After:", self.not_after)?; w.end_seq()?; self.subject.encode("Subject:", w)?; @@ -689,7 +683,7 @@ impl<'a> fmt::Display for Cert<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut printer = CertPrinter::new(f); let _ = self - .encode(&mut printer, None) + .encode(&mut printer) .map_err(|e| error!("Error decoding certificate: {}", e)); // Signature is not encoded by the Cert Decoder writeln!(f, "Signature: {:x?}", self.get_signature()) @@ -698,12 +692,11 @@ impl<'a> fmt::Display for Cert<'a> { pub struct CertVerifier<'a> { cert: &'a Cert<'a>, - utc_calendar: UtcCalendar, } impl<'a> CertVerifier<'a> { - pub fn new(cert: &'a Cert, utc_calendar: UtcCalendar) -> Self { - Self { cert, utc_calendar } + pub fn new(cert: &'a Cert) -> Self { + Self { cert } } pub fn add_cert(self, parent: &'a Cert) -> Result, Error> { @@ -711,7 +704,7 @@ impl<'a> CertVerifier<'a> { Err(ErrorCode::InvalidAuthKey)?; } let mut asn1 = [0u8; MAX_ASN1_CERT_SIZE]; - let len = self.cert.as_asn1(&mut asn1, self.utc_calendar)?; + let len = self.cert.as_asn1(&mut asn1)?; let asn1 = &asn1[..len]; let k = KeyPair::new_from_public(parent.get_pubkey())?; @@ -724,7 +717,7 @@ impl<'a> CertVerifier<'a> { })?; // TODO: other validation checks - Ok(CertVerifier::new(parent, self.utc_calendar)) + Ok(CertVerifier::new(parent)) } pub fn finalise(self) -> Result<(), Error> { @@ -751,7 +744,7 @@ pub trait CertConsumer { fn start_ctx(&mut self, tag: &str, id: u8) -> Result<(), Error>; fn end_ctx(&mut self) -> Result<(), Error>; fn oid(&mut self, tag: &str, oid: &[u8]) -> Result<(), Error>; - fn utctime(&mut self, tag: &str, epoch: u32, utc_calendar: UtcCalendar) -> Result<(), Error>; + fn utctime(&mut self, tag: &str, epoch: u32) -> Result<(), Error>; } const MAX_DEPTH: usize = 10; @@ -768,44 +761,36 @@ mod tests { use crate::tlv::{self, FromTLV, TLVWriter, TagType, ToTLV}; use crate::utils::writebuf::WriteBuf; - #[cfg(feature = "std")] #[test] fn test_asn1_encode_success() { { let mut asn1_buf = [0u8; 1000]; let c = Cert::new(&test_vectors::CHIP_CERT_INPUT1).unwrap(); - let len = c - .as_asn1(&mut asn1_buf, crate::utils::epoch::sys_utc_calendar) - .unwrap(); + let len = c.as_asn1(&mut asn1_buf).unwrap(); assert_eq!(&test_vectors::ASN1_OUTPUT1, &asn1_buf[..len]); } { let mut asn1_buf = [0u8; 1000]; let c = Cert::new(&test_vectors::CHIP_CERT_INPUT2).unwrap(); - let len = c - .as_asn1(&mut asn1_buf, crate::utils::epoch::sys_utc_calendar) - .unwrap(); + let len = c.as_asn1(&mut asn1_buf).unwrap(); assert_eq!(&test_vectors::ASN1_OUTPUT2, &asn1_buf[..len]); } { let mut asn1_buf = [0u8; 1000]; let c = Cert::new(&test_vectors::CHIP_CERT_TXT_IN_DN).unwrap(); - let len = c - .as_asn1(&mut asn1_buf, crate::utils::epoch::sys_utc_calendar) - .unwrap(); + let len = c.as_asn1(&mut asn1_buf).unwrap(); assert_eq!(&test_vectors::ASN1_OUTPUT_TXT_IN_DN, &asn1_buf[..len]); } } - #[cfg(feature = "std")] #[test] fn test_verify_chain_success() { let noc = Cert::new(&test_vectors::NOC1_SUCCESS).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); let rca = Cert::new(&test_vectors::RCA1_SUCCESS).unwrap(); - let a = noc.verify_chain_start(crate::utils::epoch::sys_utc_calendar); + let a = noc.verify_chain_start(); a.add_cert(&icac) .unwrap() .add_cert(&rca) @@ -814,7 +799,6 @@ mod tests { .unwrap(); } - #[cfg(feature = "std")] #[test] fn test_verify_chain_incomplete() { // The chain doesn't lead up to a self-signed certificate @@ -822,35 +806,33 @@ mod tests { use crate::error::ErrorCode; let noc = Cert::new(&test_vectors::NOC1_SUCCESS).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); - let a = noc.verify_chain_start(crate::utils::epoch::sys_utc_calendar); + let a = noc.verify_chain_start(); assert_eq!( Err(ErrorCode::InvalidAuthKey), a.add_cert(&icac).unwrap().finalise().map_err(|e| e.code()) ); } - #[cfg(feature = "std")] #[test] fn test_auth_key_chain_incorrect() { use crate::error::ErrorCode; let noc = Cert::new(&test_vectors::NOC1_AUTH_KEY_FAIL).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); - let a = noc.verify_chain_start(crate::utils::epoch::sys_utc_calendar); + let a = noc.verify_chain_start(); assert_eq!( Err(ErrorCode::InvalidAuthKey), a.add_cert(&icac).map(|_| ()).map_err(|e| e.code()) ); } - #[cfg(feature = "std")] #[test] fn test_cert_corrupted() { use crate::error::ErrorCode; let noc = Cert::new(&test_vectors::NOC1_CORRUPT_CERT).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); - let a = noc.verify_chain_start(crate::utils::epoch::sys_utc_calendar); + let a = noc.verify_chain_start(); assert_eq!( Err(ErrorCode::InvalidSignature), a.add_cert(&icac).map(|_| ()).map_err(|e| e.code()) diff --git a/matter/src/cert/printer.rs b/matter/src/cert/printer.rs index ae079573..a4c4efed 100644 --- a/matter/src/cert/printer.rs +++ b/matter/src/cert/printer.rs @@ -15,12 +15,11 @@ * limitations under the License. */ +use time::OffsetDateTime; + use super::{CertConsumer, MAX_DEPTH}; -use crate::{ - error::Error, - utils::epoch::{UtcCalendar, MATTER_EPOCH_SECS}, -}; -use core::{fmt, time::Duration}; +use crate::{error::Error, utils::epoch::MATTER_EPOCH_SECS}; +use core::fmt; pub struct CertPrinter<'a, 'b> { level: usize, @@ -123,10 +122,10 @@ impl<'a, 'b> CertConsumer for CertPrinter<'a, 'b> { } Ok(()) } - fn utctime(&mut self, tag: &str, epoch: u32, utc_calendar: UtcCalendar) -> Result<(), Error> { + fn utctime(&mut self, tag: &str, epoch: u32) -> Result<(), Error> { let matter_epoch = MATTER_EPOCH_SECS + epoch as u64; - let dt = utc_calendar(Duration::from_secs(matter_epoch as _)); + let dt = OffsetDateTime::from_unix_timestamp(matter_epoch as _).unwrap(); let _ = writeln!(self.f, "{} {} {:?}", SPACE[self.level], tag, dt); Ok(()) diff --git a/matter/src/core.rs b/matter/src/core.rs index 17452e3c..bed0e9e5 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -25,10 +25,7 @@ use crate::{ mdns::{Mdns, MdnsMgr}, pairing::{print_pairing_code_and_qr, DiscoveryCapabilities}, secure_channel::{pake::PaseMgr, spake2p::VerifierData}, - utils::{ - epoch::{Epoch, UtcCalendar}, - rand::Rand, - }, + utils::{epoch::Epoch, rand::Rand}, }; /// Device Commissioning Data @@ -48,17 +45,16 @@ pub struct Matter<'a> { pub mdns_mgr: RefCell>, pub epoch: Epoch, pub rand: Rand, - pub utc_calendar: UtcCalendar, pub dev_det: &'a BasicInfoConfig<'a>, } impl<'a> Matter<'a> { #[cfg(feature = "std")] pub fn new_default(dev_det: &'a BasicInfoConfig, mdns: &'a mut dyn Mdns, port: u16) -> Self { - use crate::utils::epoch::{sys_epoch, sys_utc_calendar}; + use crate::utils::epoch::sys_epoch; use crate::utils::rand::sys_rand; - Self::new(dev_det, mdns, sys_epoch, sys_rand, sys_utc_calendar, port) + Self::new(dev_det, mdns, sys_epoch, sys_rand, port) } /// Creates a new Matter object @@ -72,7 +68,6 @@ impl<'a> Matter<'a> { mdns: &'a mut dyn Mdns, epoch: Epoch, rand: Rand, - utc_calendar: UtcCalendar, port: u16, ) -> Self { Self { @@ -89,7 +84,6 @@ impl<'a> Matter<'a> { )), epoch, rand, - utc_calendar, dev_det, } } @@ -178,9 +172,3 @@ impl<'a> Borrow for Matter<'a> { &self.rand } } - -impl<'a> Borrow for Matter<'a> { - fn borrow(&self) -> &UtcCalendar { - &self.utc_calendar - } -} diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index 3fa9249b..fbd6da8b 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -32,7 +32,7 @@ use crate::{ proto_ctx::ProtoCtx, session::{CaseDetails, CloneData, NocCatIds, SessionMode}, }, - utils::{epoch::UtcCalendar, rand::Rand, writebuf::WriteBuf}, + utils::{rand::Rand, writebuf::WriteBuf}, }; #[derive(PartialEq)] @@ -70,16 +70,11 @@ impl CaseSession { pub struct Case<'a> { fabric_mgr: &'a RefCell, rand: Rand, - utc_calendar: UtcCalendar, } impl<'a> Case<'a> { - pub fn new(fabric_mgr: &'a RefCell, rand: Rand, utc_calendar: UtcCalendar) -> Self { - Self { - fabric_mgr, - rand, - utc_calendar, - } + pub fn new(fabric_mgr: &'a RefCell, rand: Rand) -> Self { + Self { fabric_mgr, rand } } pub fn casesigma3_handler( @@ -133,9 +128,7 @@ impl<'a> Case<'a> { if let Some(icac) = d.initiator_icac { initiator_icac = Some(Cert::new(icac.0)?); } - if let Err(e) = - Case::validate_certs(fabric, &initiator_noc, &initiator_icac, self.utc_calendar) - { + if let Err(e) = Case::validate_certs(fabric, &initiator_noc, &initiator_icac) { error!("Certificate Chain doesn't match: {}", e); common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?; ctx.exch_ctx.exch.close(); @@ -339,13 +332,8 @@ impl<'a> Case<'a> { Ok(()) } - fn validate_certs( - fabric: &Fabric, - noc: &Cert, - icac: &Option, - utc_calendar: UtcCalendar, - ) -> Result<(), Error> { - let mut verifier = noc.verify_chain_start(utc_calendar); + fn validate_certs(fabric: &Fabric, noc: &Cert, icac: &Option) -> Result<(), Error> { + let mut verifier = noc.verify_chain_start(); if fabric.get_fabric_id() != noc.get_fabric_id()? { Err(ErrorCode::Invalid)?; diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index 653ad741..c2fe059f 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -24,7 +24,7 @@ use crate::{ secure_channel::common::*, tlv, transport::{proto_ctx::ProtoCtx, session::CloneData}, - utils::{epoch::UtcCalendar, rand::Rand}, + utils::rand::Rand, }; use log::{error, info}; use num; @@ -46,10 +46,9 @@ impl<'a> SecureChannel<'a> { fabric_mgr: &'a RefCell, mdns: &'a RefCell>, rand: Rand, - utc_calendar: UtcCalendar, ) -> Self { SecureChannel { - case: Case::new(fabric_mgr, rand, utc_calendar), + case: Case::new(fabric_mgr, rand), pase, mdns, } diff --git a/matter/src/transport/mgr.rs b/matter/src/transport/mgr.rs index 331c3625..eeff6ff9 100644 --- a/matter/src/transport/mgr.rs +++ b/matter/src/transport/mgr.rs @@ -29,7 +29,7 @@ use crate::secure_channel::common::PROTO_ID_SECURE_CHANNEL; use crate::secure_channel::core::SecureChannel; use crate::transport::mrp::ReliableMessage; use crate::transport::{exchange, network::Address, packet::Packet}; -use crate::utils::epoch::{Epoch, UtcCalendar}; +use crate::utils::epoch::Epoch; use crate::utils::rand::Rand; use super::proto_ctx::ProtoCtx; @@ -210,8 +210,7 @@ impl<'a> TransportMgr<'a> { + Borrow> + Borrow>> + Borrow - + Borrow - + Borrow, + + Borrow, >( matter: &'a T, ) -> Self { @@ -221,7 +220,6 @@ impl<'a> TransportMgr<'a> { matter.borrow(), matter.borrow(), *matter.borrow(), - *matter.borrow(), ), *matter.borrow(), *matter.borrow(), diff --git a/matter/src/utils/epoch.rs b/matter/src/utils/epoch.rs index 7d08bfe7..8236813b 100644 --- a/matter/src/utils/epoch.rs +++ b/matter/src/utils/epoch.rs @@ -2,60 +2,15 @@ use core::time::Duration; pub type Epoch = fn() -> Duration; -pub type UtcCalendar = fn(Duration) -> UtcDate; - pub const MATTER_EPOCH_SECS: u64 = 946684800; // Seconds from 1970/01/01 00:00:00 till 2000/01/01 00:00:00 UTC -#[derive(Default, Debug, Clone, Eq, PartialEq)] -pub struct UtcDate { - pub year: u16, - pub month: u8, // 1 - 12 - pub day: u8, // 1 - 31 - pub hour: u8, // 0 - 23 - pub minute: u8, - pub second: u8, - pub millis: u16, -} - pub fn dummy_epoch() -> Duration { Duration::from_secs(0) } -pub fn dummy_utc_calendar(_duration: Duration) -> UtcDate { - Default::default() -} - #[cfg(feature = "std")] pub fn sys_epoch() -> Duration { std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() } - -#[cfg(feature = "std")] -pub fn sys_utc_calendar(duration: Duration) -> UtcDate { - use chrono::{Datelike, TimeZone, Timelike}; - use log::warn; - - let dt = match chrono::Utc.timestamp_opt(duration.as_secs() as _, duration.subsec_nanos()) { - chrono::LocalResult::None => panic!("Invalid time"), - chrono::LocalResult::Single(s) => s, - chrono::LocalResult::Ambiguous(_, a) => { - warn!( - "Ambiguous time for epoch {:?}; returning latest timestamp: {a}", - duration - ); - a - } - }; - - UtcDate { - year: dt.year() as _, - month: dt.month() as _, - day: dt.day() as _, - hour: dt.hour() as _, - minute: dt.minute() as _, - second: dt.second() as _, - millis: (dt.nanosecond() / 1000) as _, - } -} diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index 6d398a4f..ce608c50 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -106,15 +106,11 @@ pub type DmHandler<'a> = handler_chain_type!(OnOffCluster, EchoCluster, Descript pub fn matter(mdns: &mut dyn Mdns) -> Matter<'_> { #[cfg(feature = "std")] use matter::utils::epoch::sys_epoch as epoch; - #[cfg(feature = "std")] - use matter::utils::epoch::sys_utc_calendar as utc_calendar; #[cfg(not(feature = "std"))] use matter::utils::epoch::dummy_epoch as epoch; - #[cfg(not(feature = "std"))] - use matter::utils::epoch::dummy_utc_calendar as utc_calendar; - Matter::new(&BASIC_INFO, mdns, epoch, dummy_rand, utc_calendar, 5540) + Matter::new(&BASIC_INFO, mdns, epoch, dummy_rand, 5540) } /// An Interaction Model Engine to facilitate easy testing diff --git a/tools/tlv_tool/src/main.rs b/tools/tlv_tool/src/main.rs index cc08cb47..54e53745 100644 --- a/tools/tlv_tool/src/main.rs +++ b/tools/tlv_tool/src/main.rs @@ -18,7 +18,6 @@ use clap::{App, Arg}; use matter::cert; use matter::tlv; -use matter::utils::epoch::sys_utc_calendar; use simple_logger::SimpleLogger; use std::process; @@ -95,7 +94,7 @@ fn main() { } else if m.is_present("as-asn1") { let mut asn1_cert = [0_u8; 1024]; let cert = cert::Cert::new(&tlv_list[..index]).unwrap(); - let len = cert.as_asn1(&mut asn1_cert, sys_utc_calendar).unwrap(); + let len = cert.as_asn1(&mut asn1_cert).unwrap(); println!("{:02x?}", &asn1_cert[..len]); } else { tlv::print_tlv_list(&tlv_list[..index]); From 870ae6f21c849712f1b4e2b698d1be767b481df3 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Thu, 4 May 2023 12:37:21 +0000 Subject: [PATCH 42/72] Move MATTER_PORT outside of STD-only udp module --- examples/onoff_light/src/main.rs | 2 +- matter/src/core.rs | 3 +++ matter/src/transport/udp.rs | 8 +------- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 3b943058..9cae24b5 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -58,7 +58,7 @@ fn main() -> Result<(), impl Error> { let mut mdns = matter::mdns::libmdns::LibMdns::new()?; //let mut mdns = matter::mdns::DummyMdns {}; - let matter = Matter::new_default(&dev_info, &mut mdns, matter::transport::udp::MATTER_PORT); + let matter = Matter::new_default(&dev_info, &mut mdns, matter::MATTER_PORT); let dev_att = dev_att::HardCodedDevAtt::new(); diff --git a/matter/src/core.rs b/matter/src/core.rs index bed0e9e5..2fc1c3a2 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -28,6 +28,9 @@ use crate::{ utils::{epoch::Epoch, rand::Rand}, }; +/* The Matter Port */ +pub const MATTER_PORT: u16 = 5540; + /// Device Commissioning Data pub struct CommissioningData { /// The data like password or verifier that is required to authenticate diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index b29ca05a..909ab1ec 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use crate::error::*; +use crate::{error::*, MATTER_PORT}; use log::{info, warn}; use smol::net::{Ipv6Addr, UdpSocket}; @@ -27,12 +27,6 @@ pub struct UdpListener { socket: UdpSocket, } -// Currently matches with the one in connectedhomeip repo -pub const MAX_RX_BUF_SIZE: usize = 1583; - -/* The Matter Port */ -pub const MATTER_PORT: u16 = 5540; - impl UdpListener { pub async fn new() -> Result { let listener = UdpListener { From bd61c95c7dcc2ada2c7942a5ba8b6a0659514f94 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Fri, 5 May 2023 16:23:20 +0000 Subject: [PATCH 43/72] no_std needs default features switched off for several crates --- matter/Cargo.toml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 1e257574..c356956e 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -26,17 +26,17 @@ crypto_rustcrypto = ["alloc", "sha2", "hmac", "pbkdf2", "hkdf", "aes", "ccm", "p [dependencies] matter_macro_derive = { path = "../matter_macro_derive" } -bitflags = "1.3" -byteorder = "1.4.3" +bitflags = { version = "1.3", default-features = false } +byteorder = { version = "1.4.3", default-features = false } heapless = "0.7.16" -num = "0.4" +num = { version = "0.4", default-features = false } num-derive = "0.3.3" -num-traits = "0.2.15" +num-traits = { version = "0.2.15", default-features = false } strum = { version = "0.24", features = ["derive"], default-features = false } log = { version = "0.4.17", features = ["max_level_debug", "release_max_level_debug"] } no-std-net = "0.6" -subtle = "2.4.1" -safemem = "0.3.3" +subtle = { version = "2.4.1", default-features = false } +safemem = { version = "0.3.3", default-features = false } owo-colors = "3" time = { version = "0.3", default-features = false } verhoeff = { version = "1", default-features = false } @@ -53,9 +53,9 @@ async-io = { version = "=1.12", optional = true } # =1.2 for compatibility with # crypto openssl = { git = "https://github.com/sfackler/rust-openssl", optional = true } +foreign-types = { version = "0.3.2", optional = true } # rust-crypto -foreign-types = { version = "0.3.2", optional = true } sha2 = { version = "0.10", default-features = false, optional = true } hmac = { version = "0.12", optional = true } pbkdf2 = { version = "0.12", optional = true } From 1e6cd69de8b0382809aa83adafd826b3ec8bff44 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Wed, 24 May 2023 10:07:11 +0000 Subject: [PATCH 44/72] built-in mDNS; memory optimizations --- matter/Cargo.toml | 22 +- matter/src/acl.rs | 57 +- matter/src/core.rs | 6 + matter/src/crypto/crypto_rustcrypto.rs | 2 +- matter/src/data_model/objects/handler.rs | 2 +- matter/src/data_model/sdm/failsafe.rs | 1 + .../data_model/sdm/general_commissioning.rs | 2 +- matter/src/data_model/sdm/noc.rs | 2 +- matter/src/error.rs | 10 +- matter/src/fabric.rs | 64 +- matter/src/interaction_model/core.rs | 2 + matter/src/interaction_model/messages.rs | 2 +- matter/src/mdns.rs | 738 +++++++++--------- matter/src/pairing/mod.rs | 2 +- matter/src/secure_channel/case.rs | 5 +- matter/src/secure_channel/common.rs | 3 + matter/src/tlv/parser.rs | 6 +- matter/src/tlv/traits.rs | 28 + matter/src/transport/exchange.rs | 3 +- matter/src/transport/mod.rs | 1 - matter/src/transport/network.rs | 14 +- matter/src/transport/packet.rs | 35 +- matter/src/transport/plain_hdr.rs | 4 +- matter/src/transport/proto_hdr.rs | 4 +- matter/src/transport/session.rs | 41 +- matter/src/transport/udp.rs | 133 ++-- matter/src/utils/mod.rs | 1 + matter/src/utils/parsebuf.rs | 30 +- matter/src/utils/select.rs | 35 + matter/src/utils/writebuf.rs | 12 + 30 files changed, 742 insertions(+), 525 deletions(-) create mode 100644 matter/src/utils/select.rs diff --git a/matter/Cargo.toml b/matter/Cargo.toml index c356956e..22ef439e 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -15,13 +15,14 @@ name = "matter" path = "src/lib.rs" [features] -default = ["std", "crypto_mbedtls", "backtrace"] -std = ["alloc", "env_logger", "rand", "qrcode", "libmdns", "simple-mdns", "simple-dns", "async-io", "smol"] +default = ["os", "crypto_rustcrypto"] +os = ["std", "backtrace", "critical-section/std", "embassy-sync/std", "embassy-time/std"] +std = ["alloc", "env_logger", "rand", "qrcode", "async-io", "smol", "esp-idf-sys/std"] backtrace = [] alloc = [] nightly = [] crypto_openssl = ["alloc", "openssl", "foreign-types", "hmac", "sha2"] -crypto_mbedtls = ["alloc", "mbedtls", "esp-idf-sys"] +crypto_mbedtls = ["alloc", "mbedtls"] crypto_rustcrypto = ["alloc", "sha2", "hmac", "pbkdf2", "hkdf", "aes", "ccm", "p256", "elliptic-curve", "crypto-bigint", "x509-cert", "rand_core"] [dependencies] @@ -40,14 +41,16 @@ safemem = { version = "0.3.3", default-features = false } owo-colors = "3" time = { version = "0.3", default-features = false } verhoeff = { version = "1", default-features = false } +embassy-futures = "0.1" +embassy-time = { version = "0.1.1", features = ["generic-queue-8"] } +embassy-sync = "0.2" +critical-section = "1.1.1" +domain = { version = "0.7.2", default_features = false } # STD-only dependencies rand = { version = "0.8.5", optional = true } qrcode = { version = "0.12", default-features = false, optional = true } # Print QR code -simple-mdns = { version = "0.4", features = ["sync"], optional = true } -simple-dns = { version = "0.5", optional = true } astro-dnssd = { version = "0.3", optional = true } # On Linux needs avahi-compat-libdns_sd, i.e. on Ubuntu/Debian do `sudo apt-get install libavahi-compat-libdnssd-dev` -zeroconf = { version = "0.10", optional = true } smol = { version = "1.2", optional = true } # =1.2 for compatibility with ESP IDF async-io = { version = "=1.12", optional = true } # =1.2 for compatibility with ESP IDF @@ -71,14 +74,9 @@ x509-cert = { version = "0.2.0", default-features = false, features = ["pem"], o [target.'cfg(not(target_os = "espidf"))'.dependencies] mbedtls = { git = "https://github.com/fortanix/rust-mbedtls", optional = true } env_logger = { version = "0.10.0", optional = true } -libmdns = { version = "0.7", optional = true } [target.'cfg(target_os = "espidf")'.dependencies] -esp-idf-sys = { version = "0.32", default-features = false, features = ["native"], optional = true } - -[[example]] -name = "onoff_light" -path = "../examples/onoff_light/src/main.rs" +esp-idf-sys = { version = "0.33", default-features = false, features = ["native"] } [[example]] diff --git a/matter/src/acl.rs b/matter/src/acl.rs index 77b8e5be..8bd8b701 100644 --- a/matter/src/acl.rs +++ b/matter/src/acl.rs @@ -22,7 +22,7 @@ use crate::{ error::{Error, ErrorCode}, fabric, interaction_model::messages::GenericPath, - tlv::{FromTLV, TLVElement, TLVList, TLVWriter, TagType, ToTLV}, + tlv::{self, FromTLV, TLVElement, TLVList, TLVWriter, TagType, ToTLV}, transport::session::{Session, SessionMode, MAX_CAT_IDS_PER_NOC}, utils::writebuf::WriteBuf, }; @@ -390,7 +390,7 @@ impl AclEntry { const MAX_ACL_ENTRIES: usize = ENTRIES_PER_FABRIC * fabric::MAX_SUPPORTED_FABRICS; -type AclEntries = [Option; MAX_ACL_ENTRIES]; +type AclEntries = heapless::Vec, MAX_ACL_ENTRIES>; pub struct AclMgr { entries: AclEntries, @@ -398,20 +398,16 @@ pub struct AclMgr { } impl AclMgr { + #[inline(always)] pub const fn new() -> Self { - const INIT: Option = None; - Self { - entries: [INIT; MAX_ACL_ENTRIES], + entries: AclEntries::new(), changed: false, } } pub fn erase_all(&mut self) -> Result<(), Error> { - for i in 0..MAX_ACL_ENTRIES { - self.entries[i] = None; - } - + self.entries.clear(); self.changed = true; Ok(()) @@ -427,14 +423,21 @@ impl AclMgr { if cnt >= ENTRIES_PER_FABRIC { Err(ErrorCode::NoSpace)?; } - let index = self - .entries - .iter() - .position(|a| a.is_none()) - .ok_or(ErrorCode::NoSpace)?; - self.entries[index] = Some(entry); - self.changed = true; + let slot = self.entries.iter().position(|a| a.is_none()); + + if slot.is_some() || self.entries.len() < MAX_ACL_ENTRIES { + if let Some(index) = slot { + self.entries[index] = Some(entry); + } else { + self.entries + .push(Some(entry)) + .map_err(|_| ErrorCode::NoSpace) + .unwrap(); + } + + self.changed = true; + } Ok(()) } @@ -459,17 +462,13 @@ impl AclMgr { } pub fn delete_for_fabric(&mut self, fab_idx: u8) -> Result<(), Error> { - for i in 0..MAX_ACL_ENTRIES { - if self.entries[i] - .filter(|e| e.fab_idx == Some(fab_idx)) - .is_some() - { - self.entries[i] = None; + for entry in &mut self.entries { + if entry.map(|e| e.fab_idx == Some(fab_idx)).unwrap_or(false) { + *entry = None; + self.changed = true; } } - self.changed = true; - Ok(()) } @@ -505,7 +504,7 @@ impl AclMgr { pub fn load(&mut self, data: &[u8]) -> Result<(), Error> { let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?; - self.entries = AclEntries::from_tlv(&root)?; + tlv::from_tlv(&mut self.entries, &root)?; self.changed = false; Ok(()) @@ -515,7 +514,9 @@ impl AclMgr { if self.changed { let mut wb = WriteBuf::new(buf); let mut tw = TLVWriter::new(&mut wb); - self.entries.to_tlv(&mut tw, TagType::Anonymous)?; + self.entries + .as_slice() + .to_tlv(&mut tw, TagType::Anonymous)?; self.changed = false; @@ -527,6 +528,10 @@ impl AclMgr { } } + pub fn is_changed(&self) -> bool { + self.changed + } + /// Traverse fabric specific entries to find the index /// /// If the ACL Mgr has 3 entries with fabric indexes, 1, 2, 1, then the list diff --git a/matter/src/core.rs b/matter/src/core.rs index 2fc1c3a2..fa960f78 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -53,6 +53,7 @@ pub struct Matter<'a> { impl<'a> Matter<'a> { #[cfg(feature = "std")] + #[inline(always)] pub fn new_default(dev_det: &'a BasicInfoConfig, mdns: &'a mut dyn Mdns, port: u16) -> Self { use crate::utils::epoch::sys_epoch; use crate::utils::rand::sys_rand; @@ -66,6 +67,7 @@ impl<'a> Matter<'a> { /// * dev_att: An object that implements the trait [DevAttDataFetcher]. Any Matter device /// requires a set of device attestation certificates and keys. It is the responsibility of /// this object to return the device attestation details when queried upon. + #[inline(always)] pub fn new( dev_det: &'a BasicInfoConfig, mdns: &'a mut dyn Mdns, @@ -113,6 +115,10 @@ impl<'a> Matter<'a> { self.acl_mgr.borrow_mut().store(buf) } + pub fn is_changed(&self) -> bool { + self.acl_mgr.borrow().is_changed() || self.fabric_mgr.borrow().is_changed() + } + pub fn start(&self, dev_comm: CommissioningData, buf: &mut [u8]) -> Result<(), Error> { let open_comm_window = self.fabric_mgr.borrow().is_empty(); if open_comm_window { diff --git a/matter/src/crypto/crypto_rustcrypto.rs b/matter/src/crypto/crypto_rustcrypto.rs index 6212c96c..19c288ea 100644 --- a/matter/src/crypto/crypto_rustcrypto.rs +++ b/matter/src/crypto/crypto_rustcrypto.rs @@ -51,7 +51,7 @@ type AesCcm = Ccm; extern crate alloc; -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Sha256 { hasher: sha2::Sha256, } diff --git a/matter/src/data_model/objects/handler.rs b/matter/src/data_model/objects/handler.rs index a5e2b9c9..143cad87 100644 --- a/matter/src/data_model/objects/handler.rs +++ b/matter/src/data_model/objects/handler.rs @@ -49,7 +49,7 @@ impl Handler for &mut T where T: Handler, { - fn read<'a>(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { (**self).read(attr, encoder) } diff --git a/matter/src/data_model/sdm/failsafe.rs b/matter/src/data_model/sdm/failsafe.rs index 301baf91..043f5b93 100644 --- a/matter/src/data_model/sdm/failsafe.rs +++ b/matter/src/data_model/sdm/failsafe.rs @@ -49,6 +49,7 @@ pub struct FailSafe { } impl FailSafe { + #[inline(always)] pub const fn new() -> Self { Self { state: State::Idle } } diff --git a/matter/src/data_model/sdm/general_commissioning.rs b/matter/src/data_model/sdm/general_commissioning.rs index f2487ef2..b0cdff11 100644 --- a/matter/src/data_model/sdm/general_commissioning.rs +++ b/matter/src/data_model/sdm/general_commissioning.rs @@ -138,7 +138,7 @@ impl<'a> GenCommCluster<'a> { } pub fn failsafe(&self) -> &RefCell { - &self.failsafe + self.failsafe } pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index f7346cc3..b8dda3cf 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -613,7 +613,7 @@ impl<'a> NocCluster<'a> { SessionMode::Pase => { let noc_data = transaction .session_mut() - .get_noc_data::() + .get_noc_data() .ok_or(ErrorCode::NoSession)?; let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; diff --git a/matter/src/error.rs b/matter/src/error.rs index e15cbb71..c8da8208 100644 --- a/matter/src/error.rs +++ b/matter/src/error.rs @@ -165,11 +165,11 @@ impl From for Error { } } -#[cfg(all(feature = "crypto_mbedtls", target_os = "espidf"))] +#[cfg(target_os = "espidf")] impl From for Error { fn from(e: esp_idf_sys::EspError) -> Self { - ::log::error!("Error in TLS: {}", e); - Self::new(ErrorCode::TLSStack) + ::log::error!("Error in ESP: {}", e); + Self::new(ErrorCode::TLSStack) // TODO: Not a good mapping } } @@ -208,9 +208,9 @@ impl fmt::Debug for Error { #[cfg(all(feature = "std", feature = "backtrace"))] { - write!(f, "Error::{} {{\n", self)?; + writeln!(f, "Error::{} {{", self)?; write!(f, "{}", self.backtrace())?; - write!(f, "}}\n")?; + writeln!(f, "}}")?; } Ok(()) diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index 5658d4c1..f6f64ef7 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -27,7 +27,7 @@ use crate::{ error::{Error, ErrorCode}, group_keys::KeySet, mdns::{MdnsMgr, ServiceMode}, - tlv::{FromTLV, OctetStr, TLVList, TLVWriter, TagType, ToTLV, UtfStr}, + tlv::{self, FromTLV, OctetStr, TLVList, TLVWriter, TagType, ToTLV, UtfStr}, utils::writebuf::WriteBuf, }; @@ -184,7 +184,7 @@ impl Fabric { pub const MAX_SUPPORTED_FABRICS: usize = 3; -type FabricEntries = [Option; MAX_SUPPORTED_FABRICS]; +type FabricEntries = Vec, MAX_SUPPORTED_FABRICS>; pub struct FabricMgr { fabrics: FabricEntries, @@ -192,30 +192,25 @@ pub struct FabricMgr { } impl FabricMgr { + #[inline(always)] pub const fn new() -> Self { - const INIT: Option = None; - Self { - fabrics: [INIT; MAX_SUPPORTED_FABRICS], + fabrics: FabricEntries::new(), changed: false, } } pub fn load(&mut self, data: &[u8], mdns_mgr: &mut MdnsMgr) -> Result<(), Error> { - for fabric in &self.fabrics { - if let Some(fabric) = fabric { - mdns_mgr.unpublish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; - } + for fabric in self.fabrics.iter().flatten() { + mdns_mgr.unpublish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; } let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?; - self.fabrics = FabricEntries::from_tlv(&root)?; + tlv::from_tlv(&mut self.fabrics, &root)?; - for fabric in &self.fabrics { - if let Some(fabric) = fabric { - mdns_mgr.publish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; - } + for fabric in self.fabrics.iter().flatten() { + mdns_mgr.publish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; } self.changed = false; @@ -228,7 +223,9 @@ impl FabricMgr { let mut wb = WriteBuf::new(buf); let mut tw = TLVWriter::new(&mut wb); - self.fabrics.to_tlv(&mut tw, TagType::Anonymous)?; + self.fabrics + .as_slice() + .to_tlv(&mut tw, TagType::Anonymous)?; self.changed = false; @@ -240,20 +237,32 @@ impl FabricMgr { } } + pub fn is_changed(&self) -> bool { + self.changed + } + pub fn add(&mut self, f: Fabric, mdns_mgr: &mut MdnsMgr) -> Result { - for (index, fabric) in self.fabrics.iter_mut().enumerate() { - if fabric.is_none() { - mdns_mgr.publish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; + let slot = self.fabrics.iter().position(|x| x.is_none()); - *fabric = Some(f); + if slot.is_some() || self.fabrics.len() < MAX_SUPPORTED_FABRICS { + mdns_mgr.publish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; + self.changed = true; - self.changed = true; + if let Some(index) = slot { + self.fabrics[index] = Some(f); + + Ok((index + 1) as u8) + } else { + self.fabrics + .push(Some(f)) + .map_err(|_| ErrorCode::NoSpace) + .unwrap(); - return Ok((index + 1) as u8); + Ok(self.fabrics.len() as u8) } + } else { + Err(ErrorCode::NoSpace.into()) } - - Err(ErrorCode::NoSpace.into()) } pub fn remove(&mut self, fab_idx: u8, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> { @@ -311,15 +320,14 @@ impl FabricMgr { } pub fn set_label(&mut self, index: u8, label: &str) -> Result<(), Error> { - if !label.is_empty() { - if self + if !label.is_empty() + && self .fabrics .iter() .filter_map(|f| f.as_ref()) .any(|f| f.label == label) - { - return Err(ErrorCode::Invalid.into()); - } + { + return Err(ErrorCode::Invalid.into()); } let index = (index - 1) as usize; diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index 82d2eb4e..e24ec079 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -605,6 +605,7 @@ impl<'a> SubscribeReq<'a> { } } +#[derive(Debug)] pub struct ResumeReadReq { pub paths: heapless::Vec, pub filters: heapless::Vec, @@ -664,6 +665,7 @@ impl ResumeReadReq { } } +#[derive(Debug)] pub struct ResumeSubscribeReq { pub subscription_id: u32, pub paths: heapless::Vec, diff --git a/matter/src/interaction_model/messages.rs b/matter/src/interaction_model/messages.rs index edf65db9..bfd8a8b4 100644 --- a/matter/src/interaction_model/messages.rs +++ b/matter/src/interaction_model/messages.rs @@ -77,7 +77,7 @@ pub mod msg { EventPath, }; - #[derive(Default, FromTLV, ToTLV)] + #[derive(Debug, Default, FromTLV, ToTLV)] #[tlvargs(lifetime = "'a")] pub struct SubscribeReq<'a> { pub keep_subs: bool, diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index defb1374..a1876833 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -109,6 +109,7 @@ pub struct MdnsMgr<'a> { } impl<'a> MdnsMgr<'a> { + #[inline(always)] pub fn new( vid: u16, pid: u16, @@ -212,36 +213,252 @@ impl<'a> MdnsMgr<'a> { } } -#[cfg(all(feature = "std", feature = "astro-dnssd"))] -pub mod astro { - use std::collections::HashMap; +pub mod builtin { + use core::cell::RefCell; + use core::fmt::Write; + use core::pin::pin; + use core::str::FromStr; + + use domain::base::header::Flags; + use domain::base::iana::Class; + use domain::base::octets::{Octets256, Octets64, OctetsBuilder}; + use domain::base::{Dname, MessageBuilder, Record, ShortBuf}; + use domain::rdata::{Aaaa, Ptr, Srv, Txt, A}; + use embassy_futures::select::select; + use embassy_sync::blocking_mutex::raw::NoopRawMutex; + use embassy_time::{Duration, Timer}; + use log::info; - use super::Mdns; use crate::error::{Error, ErrorCode}; - use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService}; - use log::info; + use crate::transport::network::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + use crate::transport::udp::UdpListener; + use crate::utils::select::EitherUnwrap; - #[derive(Debug, Clone, Eq, PartialEq, Hash)] - pub struct ServiceId { - name: String, - service: String, - protocol: String, + const IP_BROADCAST_ADDRS: [SocketAddr; 2] = [ + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(224, 0, 0, 251)), 5353), + SocketAddr::new( + IpAddr::V6(Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x00fb)), + 5353, + ), + ]; + + const IP_BIND_ADDR: SocketAddr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 5353); + + pub fn create_record( + id: u16, + hostname: &str, + ip: [u8; 4], + ipv6: Option<[u8; 16]>, + + ttl_sec: u32, + + name: &str, + service: &str, + protocol: &str, port: u16, + service_subtypes: &[&str], + txt_kvs: &[(&str, &str)], + + buffer: &mut [u8], + ) -> Result { + let target = domain::base::octets::Octets2048::new(); + let message = MessageBuilder::from_target(target)?; + + let mut message = message.answer(); + + let mut ptr_str = heapless::String::<40>::new(); + write!(ptr_str, "{}.{}.local", service, protocol).unwrap(); + + let mut dname = heapless::String::<60>::new(); + write!(dname, "{}.{}.{}.local", name, service, protocol).unwrap(); + + let mut hname = heapless::String::<40>::new(); + write!(hname, "{}.local", hostname).unwrap(); + + let ptr: Dname = Dname::from_str(&ptr_str).unwrap(); + let record: Record, Ptr<_>> = Record::new( + Dname::from_str("_services._dns-sd._udp.local").unwrap(), + Class::In, + ttl_sec, + Ptr::new(ptr), + ); + message.push(record)?; + + let t: Dname = Dname::from_str(&dname).unwrap(); + let record: Record, Ptr<_>> = Record::new( + Dname::from_str(&ptr_str).unwrap(), + Class::In, + ttl_sec, + Ptr::new(t), + ); + message.push(record)?; + + for sub_srv in service_subtypes { + let mut ptr_str = heapless::String::<40>::new(); + write!(ptr_str, "{}._sub.{}.{}.local", sub_srv, service, protocol).unwrap(); + + let ptr: Dname = Dname::from_str(&ptr_str).unwrap(); + let record: Record, Ptr<_>> = Record::new( + Dname::from_str("_services._dns-sd._udp.local").unwrap(), + Class::In, + ttl_sec, + Ptr::new(ptr), + ); + message.push(record)?; + + let t: Dname = Dname::from_str(&dname).unwrap(); + let record: Record, Ptr<_>> = Record::new( + Dname::from_str(&ptr_str).unwrap(), + Class::In, + ttl_sec, + Ptr::new(t), + ); + message.push(record)?; + } + + let target: Dname = Dname::from_str(&hname).unwrap(); + let record: Record, Srv<_>> = Record::new( + Dname::from_str(&dname).unwrap(), + Class::In, + ttl_sec, + Srv::new(0, 0, port, target), + ); + message.push(record)?; + + // only way I found to create multiple parts in a Txt + // each slice is the length and then the data + let mut octets = Octets256::new(); + //octets.append_slice(&[1u8, b'X']).unwrap(); + //octets.append_slice(&[2u8, b'A', b'B']).unwrap(); + //octets.append_slice(&[0u8]).unwrap(); + for (k, v) in txt_kvs { + octets + .append_slice(&[(k.len() + v.len() + 1) as u8]) + .unwrap(); + octets.append_slice(k.as_bytes()).unwrap(); + octets.append_slice(&[b'=']).unwrap(); + octets.append_slice(v.as_bytes()).unwrap(); + } + + let txt = Txt::from_octets(&mut octets).unwrap(); + + let record: Record, Txt<_>> = + Record::new(Dname::from_str(&dname).unwrap(), Class::In, ttl_sec, txt); + message.push(record)?; + + let record: Record, A> = Record::new( + Dname::from_str(&hname).unwrap(), + Class::In, + ttl_sec, + A::from_octets(ip[0], ip[1], ip[2], ip[3]), + ); + message.push(record)?; + + if let Some(ipv6) = ipv6 { + let record: Record, Aaaa> = Record::new( + Dname::from_str(&hname).unwrap(), + Class::In, + ttl_sec, + Aaaa::new(ipv6.into()), + ); + message.push(record)?; + } + + let headerb = message.header_mut(); + headerb.set_id(id); + headerb.set_opcode(domain::base::iana::Opcode::Query); + headerb.set_rcode(domain::base::iana::Rcode::NoError); + + let mut flags = Flags::new(); + flags.qr = true; + flags.aa = true; + headerb.set_flags(flags); + + let target = message.finish(); + + buffer[..target.len()].copy_from_slice(target.as_ref()); + + Ok(target.len()) } - pub struct AstroMdns { - services: HashMap, + pub type Notification = embassy_sync::signal::Signal; + + #[derive(Debug, Clone)] + struct MdnsEntry { + key: heapless::String<64>, + record: heapless::Vec, } - impl AstroMdns { - pub fn new() -> Result { - Ok(Self { - services: HashMap::new(), - }) + impl MdnsEntry { + #[inline(always)] + const fn new() -> Self { + Self { + key: heapless::String::new(), + record: heapless::Vec::new(), + } } + } + + pub struct Mdns<'a> { + id: u16, + hostname: &'a str, + ip: [u8; 4], + ipv6: Option<[u8; 16]>, + entries: RefCell>, + notification: Notification, + udp: RefCell>, + } + impl<'a> Mdns<'a> { + #[inline(always)] + pub const fn new(id: u16, hostname: &'a str, ip: [u8; 4], ipv6: Option<[u8; 16]>) -> Self { + Self { + id, + hostname, + ip, + ipv6, + entries: RefCell::new(heapless::Vec::new()), + notification: Notification::new(), + udp: RefCell::new(None), + } + } + + pub fn split(&mut self) -> (MdnsApi<'_, 'a>, MdnsRunner<'_, 'a>) { + (MdnsApi(&*self), MdnsRunner(&*self)) + } + + async fn bind(&self) -> Result<(), Error> { + if self.udp.borrow().is_none() { + *self.udp.borrow_mut() = Some(UdpListener::new(IP_BIND_ADDR).await?); + } + + Ok(()) + } + + pub fn close(&mut self) { + *self.udp.borrow_mut() = None; + } + + fn key( + &self, + name: &str, + service: &str, + protocol: &str, + port: u16, + ) -> heapless::String<64> { + let mut key = heapless::String::new(); + + write!(&mut key, "{name}.{service}.{protocol}.{port}").unwrap(); + + key + } + } + + pub struct MdnsApi<'a, 'b>(&'a Mdns<'b>); + + impl<'a, 'b> MdnsApi<'a, 'b> { pub fn add( - &mut self, + &self, name: &str, service: &str, protocol: &str, @@ -250,66 +467,142 @@ pub mod astro { txt_kvs: &[(&str, &str)], ) -> Result<(), Error> { info!( - "Registering mDNS service {}/{}.{} [{:?}]/{}", - name, service, protocol, service_subtypes, port + "Registering mDNS service {}/{}.{} [{:?}]/{}, keys [{:?}]", + name, service, protocol, service_subtypes, port, txt_kvs ); - let _ = self.remove(name, service, protocol, port); + let key = self.0.key(name, service, protocol, port); - let composite_service_type = if !service_subtypes.is_empty() { - format!("{}.{},{}", service, protocol, service_subtypes.join(",")) - } else { - format!("{}.{}", service, protocol) - }; + let mut entries = self.0.entries.borrow_mut(); - let mut builder = DNSServiceBuilder::new(&composite_service_type, port).with_name(name); + entries.retain(|entry| entry.key != key); + entries + .push(MdnsEntry::new()) + .map_err(|_| ErrorCode::NoSpace)?; - for kvs in txt_kvs { - info!("mDNS TXT key {} val {}", kvs.0, kvs.1); - builder = builder.with_key_value(kvs.0.to_string(), kvs.1.to_string()); + let entry = entries.iter_mut().last().unwrap(); + entry + .record + .resize(1024, 0) + .map_err(|_| ErrorCode::NoSpace) + .unwrap(); + + match create_record( + self.0.id, + self.0.hostname, + self.0.ip, + self.0.ipv6, + 60, /*ttl_sec*/ + name, + service, + protocol, + port, + service_subtypes, + txt_kvs, + &mut entry.record, + ) { + Ok(len) => entry.record.truncate(len), + Err(_) => { + entries.pop(); + Err(ErrorCode::NoSpace)?; + } } - let svc = builder.register().map_err(|_| ErrorCode::MdnsError)?; - - self.services.insert( - ServiceId { - name: name.into(), - service: service.into(), - protocol: protocol.into(), - port, - }, - svc, - ); + self.0.notification.signal(()); Ok(()) } pub fn remove( - &mut self, + &self, name: &str, service: &str, protocol: &str, port: u16, ) -> Result<(), Error> { - let id = ServiceId { - name: name.into(), - service: service.into(), - protocol: protocol.into(), - port, - }; + info!( + "Deregistering mDNS service {}/{}.{}/{}", + name, service, protocol, port + ); - if self.services.remove(&id).is_some() { - info!( - "Deregistering mDNS service {}/{}.{}/{}", - name, service, protocol, port - ); + let key = self.0.key(name, service, protocol, port); + + let mut entries = self.0.entries.borrow_mut(); + + let old_len = entries.len(); + + entries.retain(|entry| entry.key != key); + + if entries.len() != old_len { + self.0.notification.signal(()); } Ok(()) } } - impl Mdns for AstroMdns { + pub struct MdnsRunner<'a, 'b>(&'a Mdns<'b>); + + impl<'a, 'b> MdnsRunner<'a, 'b> { + pub async fn run(&mut self) -> Result<(), Error> { + let mut broadcast = pin!(self.broadcast()); + let mut respond = pin!(self.respond()); + + select(&mut broadcast, &mut respond).await.unwrap() + } + + async fn broadcast(&self) -> Result<(), Error> { + loop { + select( + self.0.notification.wait(), + Timer::after(Duration::from_secs(30)), + ) + .await; + + let mut index = 0; + + while let Some(entry) = self + .0 + .entries + .borrow() + .get(index) + .map(|entry| entry.clone()) + { + info!("Broadasting mDNS entry {}", &entry.key); + + self.0.bind().await?; + + let udp = self.0.udp.borrow(); + let udp = udp.as_ref().unwrap(); + + for addr in IP_BROADCAST_ADDRS { + udp.send(addr, &entry.record).await?; + } + + index += 1; + } + } + } + + async fn respond(&self) -> Result<(), Error> { + loop { + let mut buf = [0; 1580]; + + let udp = self.0.udp.borrow(); + let udp = udp.as_ref().unwrap(); + + let (_len, _addr) = udp.recv(&mut buf).await?; + + info!("Received UDP packet"); + + // TODO: Process the incoming packed and only answer what we are being queried about + + self.0.notification.signal(()); + } + } + } + + impl<'a, 'b> super::Mdns for MdnsApi<'a, 'b> { fn add( &mut self, name: &str, @@ -319,7 +612,7 @@ pub mod astro { service_subtypes: &[&str], txt_kvs: &[(&str, &str)], ) -> Result<(), Error> { - AstroMdns::add( + MdnsApi::add( self, name, service, @@ -337,152 +630,19 @@ pub mod astro { protocol: &str, port: u16, ) -> Result<(), Error> { - AstroMdns::remove(self, name, service, protocol, port) + MdnsApi::remove(self, name, service, protocol, port) } } } -// TODO: Maybe future -// #[cfg(all(feature = "std", feature = "zeroconf"))] -// pub mod zeroconf { -// use std::collections::HashMap; - -// use super::Mdns; -// use crate::error::{Error, ErrorCode}; -// use log::info; -// use zeroconf::prelude::*; -// use zeroconf::{MdnsService, ServiceType, TxtRecord}; - -// #[derive(Debug, Clone, Eq, PartialEq, Hash)] -// pub struct ServiceId { -// name: String, -// service: String, -// protocol: String, -// port: u16, -// } - -// pub struct ZeroconfMdns { -// services: HashMap, -// } - -// impl ZeroconfMdns { -// pub fn new() -> Result { -// Ok(Self { -// services: HashMap::new(), -// }) -// } - -// pub fn add( -// &mut self, -// name: &str, -// service: &str, -// protocol: &str, -// port: u16, -// service_subtypes: &[&str], -// txt_kvs: &[(&str, &str)], -// ) -> Result<(), Error> { -// info!( -// "Registering mDNS service {}/{}.{} [{:?}]/{}", -// name, service, protocol, service_subtypes, port -// ); - -// let _ = self.remove(name, service, protocol, port); - -// let mut svc = MdnsService::new( -// ServiceType::with_sub_types(service, protocol, service_subtypes.into()).unwrap(), -// port, -// ); - -// let mut txt = TxtRecord::new(); - -// for kvs in txt_kvs { -// info!("mDNS TXT key {} val {}", kvs.0, kvs.1); -// txt.insert(kvs.0, kvs.1); -// } - -// svc.set_txt_record(txt); - -// //let event_loop = svc.register().map_err(|_| ErrorCode::MdnsError)?; - -// self.services.insert( -// ServiceId { -// name: name.into(), -// service: service.into(), -// protocol: protocol.into(), -// port, -// }, -// svc, -// ); - -// Ok(()) -// } - -// pub fn remove( -// &mut self, -// name: &str, -// service: &str, -// protocol: &str, -// port: u16, -// ) -> Result<(), Error> { -// let id = ServiceId { -// name: name.into(), -// service: service.into(), -// protocol: protocol.into(), -// port, -// }; - -// if self.services.remove(&id).is_some() { -// info!( -// "Deregistering mDNS service {}.{}/{}/{}", -// name, service, protocol, port -// ); -// } - -// Ok(()) -// } -// } - -// impl Mdns for ZeroconfMdns { -// fn add( -// &mut self, -// name: &str, -// service: &str, -// protocol: &str, -// port: u16, -// service_subtypes: &[&str], -// txt_kvs: &[(&str, &str)], -// ) -> Result<(), Error> { -// ZeroconfMdns::add( -// self, -// name, -// service, -// protocol, -// port, -// service_subtypes, -// txt_kvs, -// ) -// } - -// fn remove( -// &mut self, -// name: &str, -// service: &str, -// protocol: &str, -// port: u16, -// ) -> Result<(), Error> { -// ZeroconfMdns::remove(self, name, service, protocol, port) -// } -// } -// } - -#[cfg(all(feature = "std", not(target_os = "espidf")))] -pub mod libmdns { +#[cfg(all(feature = "std", feature = "astro-dnssd"))] +pub mod astro { + use std::collections::HashMap; + use super::Mdns; - use crate::error::Error; - use libmdns::{Responder, Service}; + use crate::error::{Error, ErrorCode}; + use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService}; use log::info; - use std::collections::HashMap; - use std::vec::Vec; #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct ServiceId { @@ -492,17 +652,13 @@ pub mod libmdns { port: u16, } - pub struct LibMdns { - responder: Responder, - services: HashMap, + pub struct AstroMdns { + services: HashMap, } - impl LibMdns { + impl AstroMdns { pub fn new() -> Result { - let responder = Responder::new()?; - Ok(Self { - responder, services: HashMap::new(), }) } @@ -513,28 +669,30 @@ pub mod libmdns { service: &str, protocol: &str, port: u16, + service_subtypes: &[&str], txt_kvs: &[(&str, &str)], ) -> Result<(), Error> { info!( - "Registering mDNS service {}/{}.{}/{}", - name, service, protocol, port + "Registering mDNS service {}/{}.{} [{:?}]/{}", + name, service, protocol, service_subtypes, port ); let _ = self.remove(name, service, protocol, port); - let mut properties = Vec::new(); + let composite_service_type = if !service_subtypes.is_empty() { + format!("{}.{},{}", service, protocol, service_subtypes.join(",")) + } else { + format!("{}.{}", service, protocol) + }; + + let mut builder = DNSServiceBuilder::new(&composite_service_type, port).with_name(name); + for kvs in txt_kvs { info!("mDNS TXT key {} val {}", kvs.0, kvs.1); - properties.push(format!("{}={}", kvs.0, kvs.1)); + builder = builder.with_key_value(kvs.0.to_string(), kvs.1.to_string()); } - let properties: Vec<&str> = properties.iter().map(|entry| entry.as_str()).collect(); - let svc = self.responder.register( - format!("{}.{}", service, protocol), - name.to_owned(), - port, - &properties, - ); + let svc = builder.register().map_err(|_| ErrorCode::MdnsError)?; self.services.insert( ServiceId { @@ -574,17 +732,25 @@ pub mod libmdns { } } - impl Mdns for LibMdns { + impl Mdns for AstroMdns { fn add( &mut self, name: &str, service: &str, protocol: &str, port: u16, - _service_subtypes: &[&str], + service_subtypes: &[&str], txt_kvs: &[(&str, &str)], ) -> Result<(), Error> { - LibMdns::add(self, name, service, protocol, port, txt_kvs) + AstroMdns::add( + self, + name, + service, + protocol, + port, + service_subtypes, + txt_kvs, + ) } fn remove( @@ -594,147 +760,11 @@ pub mod libmdns { protocol: &str, port: u16, ) -> Result<(), Error> { - LibMdns::remove(self, name, service, protocol, port) + AstroMdns::remove(self, name, service, protocol, port) } } } -// TODO: Maybe future -// #[cfg(feature = "std")] -// pub mod simplemdns { -// use std::net::Ipv4Addr; - -// use crate::error::{Error, ErrorCode}; -// use super::Mdns; -// use log::info; -// use simple_dns::{ -// rdata::{RData, A, SRV, TXT, PTR}, -// CharacterString, Name, ResourceRecord, CLASS, -// }; -// use simple_mdns::sync_discovery::SimpleMdnsResponder; - -// #[derive(Debug, Clone, Eq, PartialEq, Hash)] -// pub struct ServiceId { -// name: String, -// service_type: String, -// port: u16, -// } - -// pub struct SimpleMdns { -// responder: SimpleMdnsResponder, -// } - -// impl SimpleMdns { -// pub fn new() -> Result { -// Ok(Self { -// responder: Default::default(), -// }) -// } - -// pub fn add( -// &mut self, -// name: &str, -// service_type: &str, -// port: u16, -// txt_kvs: &[(&str, &str)], -// ) -> Result<(), Error> { -// info!( -// "Registering mDNS service {}/{}/{}", -// name, service_type, port -// ); - -// let _ = self.remove(name, service_type, port); - -// let mut txt = TXT::new(); -// for kvs in txt_kvs { -// info!("mDNS TXT key {} val {}", kvs.0, kvs.1); - -// let string = format!("{}={}", kvs.0, kvs.1); -// txt.add_char_string( -// CharacterString::new(string.as_bytes()) -// .unwrap() -// .into_owned(), -// ); -// } - -// let name = Name::new_unchecked(name).into_owned(); -// let service_type = Name::new_unchecked(service_type).into_owned(); - -// self.responder.add_resource(ResourceRecord::new( -// name.clone(), -// CLASS::IN, -// 10, -// RData::A(A { -// address: Ipv4Addr::new(192, 168, 10, 189).into(), -// }), -// )); - -// self.responder.add_resource(ResourceRecord::new( -// name.clone(), -// CLASS::IN, -// 10, -// RData::SRV(SRV { -// port: port, -// priority: 0, -// weight: 0, -// target: service_type.clone(), -// }), -// )); - -// self.responder.add_resource(ResourceRecord::new( -// srv_name.clone(), -// CLASS::IN, -// 10, -// RData::PTR(PTR(srv_name.clone()), -// ))); - -// self.responder.add_resource(ResourceRecord::new( -// srv_name, -// CLASS::IN, -// 10, -// RData::TXT(txt), -// )); - -// Ok(()) -// } - -// pub fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { -// // TODO -// // let id = ServiceId { -// // name: name.into(), -// // service_type: service_type.into(), -// // port, -// // }; - -// // if self.responder.remove_resource_record(resource).remove(&id).is_some() { -// // info!( -// // "Deregistering mDNS service {}/{}/{}", -// // name, service_type, port -// // ); -// // } - -// Ok(()) -// } -// } - -// impl Mdns for SimpleMdns { -// fn add( -// &mut self, -// name: &str, -// service_type: &str, -// port: u16, -// _service_subtypes: &[&str], -// txt_kvs: &[(&str, &str)], -// ) -> Result<(), Error> { -// SimpleMdns::add(self, name, service_type, port, txt_kvs) -// } - -// fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { -// SimpleMdns::remove(self, name, service_type, port) -// } -// } -// } - #[cfg(test)] mod tests { use super::*; diff --git a/matter/src/pairing/mod.rs b/matter/src/pairing/mod.rs index 2dddce56..253062e0 100644 --- a/matter/src/pairing/mod.rs +++ b/matter/src/pairing/mod.rs @@ -91,7 +91,7 @@ pub fn print_pairing_code_and_qr( let qr_code = compute_qr_code(dev_det, comm_data, discovery_capabilities, buf)?; pretty_print_pairing_code(&pairing_code); - print_qr_code(&qr_code); + print_qr_code(qr_code); Ok(()) } diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index fbd6da8b..c029963a 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -35,12 +35,13 @@ use crate::{ utils::{rand::Rand, writebuf::WriteBuf}, }; -#[derive(PartialEq)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] enum State { Sigma1Rx, Sigma3Rx, } +#[derive(Debug, Clone)] pub struct CaseSession { state: State, peer_sessid: u16, @@ -84,7 +85,7 @@ impl<'a> Case<'a> { let mut case_session = ctx .exch_ctx .exch - .take_case_session::() + .take_case_session() .ok_or(ErrorCode::InvalidState)?; if case_session.state != State::Sigma1Rx { Err(ErrorCode::Invalid)?; diff --git a/matter/src/secure_channel/common.rs b/matter/src/secure_channel/common.rs index 7049ba38..c007ee5f 100644 --- a/matter/src/secure_channel/common.rs +++ b/matter/src/secure_channel/common.rs @@ -56,6 +56,8 @@ pub fn create_sc_status_report( status_code: SCStatusCodes, proto_data: Option<&[u8]>, ) -> Result<(), Error> { + proto_tx.reset(); + let general_code = match status_code { SCStatusCodes::SessionEstablishmentSuccess => GeneralCode::Success, SCStatusCodes::CloseSession => { @@ -79,6 +81,7 @@ pub fn create_sc_status_report( } pub fn create_mrp_standalone_ack(proto_tx: &mut Packet) { + proto_tx.reset(); proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); proto_tx.set_proto_opcode(OpCode::MRPStandAloneAck as u8); proto_tx.unset_reliable(); diff --git a/matter/src/tlv/parser.rs b/matter/src/tlv/parser.rs index b740f5d0..0c179e2a 100644 --- a/matter/src/tlv/parser.rs +++ b/matter/src/tlv/parser.rs @@ -711,11 +711,7 @@ impl<'a> Iterator for TLVContainerIterator<'a> { return None; } - if is_container(element.element_type) { - self.prev_container = true; - } else { - self.prev_container = false; - } + self.prev_container = is_container(element.element_type); Some(element) } } diff --git a/matter/src/tlv/traits.rs b/matter/src/tlv/traits.rs index 3fced12d..28c236be 100644 --- a/matter/src/tlv/traits.rs +++ b/matter/src/tlv/traits.rs @@ -61,6 +61,24 @@ impl<'a, T: FromTLV<'a> + Default, const N: usize> FromTLV<'a> for [T; N] { } } +pub fn from_tlv<'a, T: FromTLV<'a>, const N: usize>( + vec: &mut heapless::Vec, + t: &TLVElement<'a>, +) -> Result<(), Error> { + vec.clear(); + + t.confirm_array()?; + + if let Some(tlv_iter) = t.enter() { + for element in tlv_iter { + vec.push(T::from_tlv(&element)?) + .map_err(|_| ErrorCode::NoSpace)?; + } + } + + Ok(()) +} + macro_rules! fromtlv_for { ($($t:ident)*) => { $( @@ -110,6 +128,16 @@ impl ToTLV for [T; N] { } } +impl<'a, T: ToTLV> ToTLV for &'a [T] { + fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { + tw.start_array(tag)?; + for i in *self { + i.to_tlv(tw, TagType::Anonymous)?; + } + tw.end_container() + } +} + // Generate ToTLV for standard data types totlv_for!(i8 u8 u16 u32 u64 bool); diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index 57f666c1..04b63db1 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -15,7 +15,6 @@ * limitations under the License. */ -use core::any::Any; use core::fmt; use core::time::Duration; use log::{error, info, trace}; @@ -144,7 +143,7 @@ impl Exchange { } } - pub fn take_case_session(&mut self) -> Option { + pub fn take_case_session(&mut self) -> Option { let old = core::mem::replace(&mut self.data, DataOption::None); if let DataOption::CaseSession(session) = old { Some(session) diff --git a/matter/src/transport/mod.rs b/matter/src/transport/mod.rs index 0b6453ee..1a81c75c 100644 --- a/matter/src/transport/mod.rs +++ b/matter/src/transport/mod.rs @@ -25,5 +25,4 @@ pub mod plain_hdr; pub mod proto_ctx; pub mod proto_hdr; pub mod session; -#[cfg(feature = "std")] pub mod udp; diff --git a/matter/src/transport/network.rs b/matter/src/transport/network.rs index e03658b9..ba50386d 100644 --- a/matter/src/transport/network.rs +++ b/matter/src/transport/network.rs @@ -17,15 +17,23 @@ use core::fmt::{Debug, Display}; #[cfg(not(feature = "std"))] -pub use no_std_net::{IpAddr, Ipv4Addr, SocketAddr}; +pub use no_std_net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; #[cfg(feature = "std")] -pub use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +pub use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; -#[derive(PartialEq, Copy, Clone)] +#[derive(Eq, PartialEq, Copy, Clone)] pub enum Address { Udp(SocketAddr), } +impl Address { + pub fn unwrap_udp(self) -> SocketAddr { + match self { + Self::Udp(addr) => addr, + } + } +} + impl Default for Address { fn default() -> Self { Address::Udp(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 8080)) diff --git a/matter/src/transport/packet.rs b/matter/src/transport/packet.rs index 3e7e9c75..72368cb5 100644 --- a/matter/src/transport/packet.rs +++ b/matter/src/transport/packet.rs @@ -31,7 +31,7 @@ use super::{ pub const MAX_RX_BUF_SIZE: usize = 1583; pub const MAX_TX_BUF_SIZE: usize = 1280 - 40/*IPV6 header size*/ - 8/*UDP header size*/; -#[derive(PartialEq)] +#[derive(Debug, PartialEq, Eq, Copy, Clone)] enum RxState { Uninit, PlainDecode, @@ -43,6 +43,30 @@ enum Direction<'a> { Rx(ParseBuf<'a>, RxState), } +impl<'a> Direction<'a> { + pub fn load(&mut self, direction: &Direction) -> Result<(), Error> { + if matches!(self, Self::Tx(_)) != matches!(direction, Direction::Tx(_)) { + Err(ErrorCode::Invalid)?; + } + + match self { + Self::Tx(wb) => match direction { + Direction::Tx(src_wb) => wb.load(src_wb)?, + Direction::Rx(_, _) => Err(ErrorCode::Invalid)?, + }, + Self::Rx(pb, state) => match direction { + Direction::Tx(_) => Err(ErrorCode::Invalid)?, + Direction::Rx(src_pb, src_state) => { + pb.load(src_pb)?; + *state = *src_state; + } + }, + } + + Ok(()) + } +} + pub struct Packet<'a> { pub plain: PlainHdr, pub proto: ProtoHdr, @@ -78,7 +102,7 @@ impl<'a> Packet<'a> { } } - pub fn reset(&mut self) -> () { + pub fn reset(&mut self) { if let Direction::Tx(wb) = &mut self.data { wb.reset(); wb.reserve(Packet::HDR_RESERVE).unwrap(); @@ -91,6 +115,13 @@ impl<'a> Packet<'a> { } } + pub fn load(&mut self, packet: &Packet) -> Result<(), Error> { + self.plain = packet.plain.clone(); + self.proto = packet.proto.clone(); + self.peer = packet.peer; + self.data.load(&packet.data) + } + pub fn as_slice(&self) -> &[u8] { match &self.data { Direction::Rx(pb, _) => pb.as_slice(), diff --git a/matter/src/transport/plain_hdr.rs b/matter/src/transport/plain_hdr.rs index e5a9b24e..5a0728af 100644 --- a/matter/src/transport/plain_hdr.rs +++ b/matter/src/transport/plain_hdr.rs @@ -21,7 +21,7 @@ use crate::utils::writebuf::WriteBuf; use bitflags::bitflags; use log::info; -#[derive(Debug, PartialEq, Default)] +#[derive(Debug, PartialEq, Eq, Default, Copy, Clone)] pub enum SessionType { #[default] None, @@ -38,7 +38,7 @@ bitflags! { } // This is the unencrypted message -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct PlainHdr { pub flags: MsgFlags, pub sess_type: SessionType, diff --git a/matter/src/transport/proto_hdr.rs b/matter/src/transport/proto_hdr.rs index d7f92fb4..9bf80d43 100644 --- a/matter/src/transport/proto_hdr.rs +++ b/matter/src/transport/proto_hdr.rs @@ -36,7 +36,7 @@ bitflags! { } } -#[derive(Default)] +#[derive(Debug, Default, Clone)] pub struct ProtoHdr { pub exch_id: u16, pub exch_flags: ExchFlags, @@ -278,7 +278,7 @@ mod tests { decrypt_in_place(recvd_ctr, 0, &mut parsebuf, &key).unwrap(); assert_eq!( - parsebuf.into_slice(), + parsebuf.as_slice(), [ 0x5, 0x8, 0x70, 0x0, 0x1, 0x0, 0x15, 0x28, 0x0, 0x28, 0x1, 0x36, 0x2, 0x15, 0x37, 0x0, 0x24, 0x0, 0x0, 0x24, 0x1, 0x30, 0x24, 0x2, 0x2, 0x18, 0x35, 0x1, 0x24, 0x0, diff --git a/matter/src/transport/session.rs b/matter/src/transport/session.rs index 1135f05d..1e3a1d4f 100644 --- a/matter/src/transport/session.rs +++ b/matter/src/transport/session.rs @@ -19,11 +19,8 @@ use crate::data_model::sdm::noc::NocData; use crate::utils::epoch::Epoch; use crate::utils::rand::Rand; use core::fmt; +use core::ops::{Deref, DerefMut}; use core::time::Duration; -use core::{ - any::Any, - ops::{Deref, DerefMut}, -}; use crate::{ error::*, @@ -166,7 +163,7 @@ impl Session { self.data = None; } - pub fn get_noc_data(&mut self) -> Option<&mut NocData> { + pub fn get_noc_data(&mut self) -> Option<&mut NocData> { self.data.as_mut() } @@ -325,17 +322,16 @@ pub const MAX_SESSIONS: usize = 16; pub struct SessionMgr { next_sess_id: u16, - sessions: [Option; MAX_SESSIONS], + sessions: heapless::Vec, MAX_SESSIONS>, epoch: Epoch, rand: Rand, } impl SessionMgr { + #[inline(always)] pub fn new(epoch: Epoch, rand: Rand) -> Self { - const INIT: Option = None; - Self { - sessions: [INIT; MAX_SESSIONS], + sessions: heapless::Vec::new(), next_sess_id: 1, epoch, rand, @@ -343,10 +339,10 @@ impl SessionMgr { } pub fn mut_by_index(&mut self, index: usize) -> Option<&mut Session> { - self.sessions[index].as_mut() + self.sessions.get_mut(index).and_then(Option::as_mut) } - fn get_next_sess_id(&mut self) -> u16 { + pub fn get_next_sess_id(&mut self) -> u16 { let mut next_sess_id: u16; loop { next_sess_id = self.next_sess_id; @@ -366,7 +362,7 @@ impl SessionMgr { } pub fn get_session_for_eviction(&self) -> Option { - if self.get_empty_slot().is_none() { + if self.sessions.len() == MAX_SESSIONS && self.get_empty_slot().is_none() { Some(self.get_lru()) } else { None @@ -380,8 +376,8 @@ impl SessionMgr { fn get_lru(&self) -> usize { let mut lru_index = 0; let mut lru_ts = (self.epoch)(); - for i in 0..MAX_SESSIONS { - if let Some(s) = &self.sessions[i] { + for (i, s) in self.sessions.iter().enumerate() { + if let Some(s) = s { if s.last_use < lru_ts { lru_ts = s.last_use; lru_index = i; @@ -405,10 +401,17 @@ impl SessionMgr { /// We could have returned a SessionHandle here. But the borrow checker doesn't support /// non-lexical lifetimes. This makes it harder for the caller of this function to take /// action in the error return path - pub fn add_session(&mut self, session: Session) -> Result { + fn add_session(&mut self, session: Session) -> Result { if let Some(index) = self.get_empty_slot() { self.sessions[index] = Some(session); Ok(index) + } else if self.sessions.len() < MAX_SESSIONS { + self.sessions + .push(Some(session)) + .map_err(|_| ErrorCode::NoSpace) + .unwrap(); + + Ok(self.sessions.len() - 1) } else { Err(ErrorCode::NoSpace.into()) } @@ -419,7 +422,7 @@ impl SessionMgr { self.add_session(session) } - fn _get( + pub fn get( &self, sess_id: u16, peer_addr: Address, @@ -451,14 +454,14 @@ impl SessionMgr { Some(self.get_session_handle(index)) } - pub fn get_or_add( + fn get_or_add( &mut self, sess_id: u16, peer_addr: Address, peer_nodeid: Option, is_encrypted: bool, ) -> Result { - if let Some(index) = self._get(sess_id, peer_addr, peer_nodeid, is_encrypted) { + if let Some(index) = self.get(sess_id, peer_addr, peer_nodeid, is_encrypted) { Ok(index) } else if sess_id == 0 && !is_encrypted { // We must create a new session for this case @@ -538,7 +541,7 @@ impl fmt::Display for SessionMgr { } pub struct SessionHandle<'a> { - sess_mgr: &'a mut SessionMgr, + pub(crate) sess_mgr: &'a mut SessionMgr, sess_idx: usize, } diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index 909ab1ec..7cf52889 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -15,64 +15,103 @@ * limitations under the License. */ -use crate::{error::*, MATTER_PORT}; -use log::{info, warn}; -use smol::net::{Ipv6Addr, UdpSocket}; +#[cfg(feature = "std")] +pub use smol_udp::*; -use super::network::Address; +#[cfg(not(feature = "std"))] +pub use dummy_udp::*; -// We could get rid of the smol here, but keeping it around in case we have to process -// any other events in this thread's context -pub struct UdpListener { - socket: UdpSocket, -} - -impl UdpListener { - pub async fn new() -> Result { - let listener = UdpListener { - socket: UdpSocket::bind((Ipv6Addr::UNSPECIFIED, MATTER_PORT)).await?, - }; +#[cfg(feature = "std")] +mod smol_udp { + use crate::error::*; + use log::{debug, info, warn}; + use smol::net::UdpSocket; - info!( - "Listening on {:?} port {}", - Ipv6Addr::UNSPECIFIED, - MATTER_PORT - ); + use crate::transport::network::SocketAddr; - Ok(listener) + pub struct UdpListener { + socket: UdpSocket, } - pub async fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, Address), Error> { - info!("Waiting for incoming packets"); + impl UdpListener { + pub async fn new(addr: SocketAddr) -> Result { + let listener = UdpListener { + socket: UdpSocket::bind((addr.ip(), addr.port())).await?, + }; + + info!("Listening on {:?}", addr); + + Ok(listener) + } + + pub async fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, SocketAddr), Error> { + info!("Waiting for incoming packets"); - let (size, addr) = self.socket.recv_from(in_buf).await.map_err(|e| { - warn!("Error on the network: {:?}", e); - ErrorCode::Network - })?; + let (size, addr) = self.socket.recv_from(in_buf).await.map_err(|e| { + warn!("Error on the network: {:?}", e); + ErrorCode::Network + })?; - info!("Got packet: {:?} from addr {:?}", &in_buf[..size], addr); + debug!("Got packet {:?} from addr {:?}", &in_buf[..size], addr); - Ok((size, Address::Udp(addr))) + Ok((size, addr)) + } + + pub async fn send(&self, addr: SocketAddr, out_buf: &[u8]) -> Result { + let len = self.socket.send_to(out_buf, addr).await.map_err(|e| { + warn!("Error on the network: {:?}", e); + ErrorCode::Network + })?; + + debug!( + "Send packet {:?} ({}/{}) to addr {:?}", + out_buf, + out_buf.len(), + len, + addr + ); + + Ok(len) + } } +} + +#[cfg(not(feature = "std"))] +mod dummy_udp { + use core::future::pending; + + use crate::error::*; + use log::{debug, info}; + + use crate::transport::network::SocketAddr; + + pub struct UdpListener {} + + impl UdpListener { + pub async fn new(addr: SocketAddr) -> Result { + let listener = UdpListener {}; + + info!("Pretending to listen on {:?}", addr); + + Ok(listener) + } + + pub async fn recv(&self, _in_buf: &mut [u8]) -> Result<(usize, SocketAddr), Error> { + info!("Pretending to wait for incoming packets (looping forever)"); + + pending().await + } + + pub async fn send(&self, addr: SocketAddr, out_buf: &[u8]) -> Result { + debug!( + "Send packet {:?} ({}/{}) to addr {:?}", + out_buf, + out_buf.len(), + out_buf.len(), + addr + ); - pub async fn send(&self, addr: Address, out_buf: &[u8]) -> Result { - match addr { - Address::Udp(addr) => { - let len = self.socket.send_to(out_buf, addr).await.map_err(|e| { - warn!("Error on the network: {:?}", e); - ErrorCode::Network - })?; - - info!( - "Send packet: {:?} ({}/{}) to addr {:?}", - out_buf, - out_buf.len(), - len, - addr - ); - - Ok(len) - } + Ok(out_buf.len()) } } } diff --git a/matter/src/utils/mod.rs b/matter/src/utils/mod.rs index 1e69b847..5a3fe811 100644 --- a/matter/src/utils/mod.rs +++ b/matter/src/utils/mod.rs @@ -18,4 +18,5 @@ pub mod epoch; pub mod parsebuf; pub mod rand; +pub mod select; pub mod writebuf; diff --git a/matter/src/utils/parsebuf.rs b/matter/src/utils/parsebuf.rs index 549e022b..233693cd 100644 --- a/matter/src/utils/parsebuf.rs +++ b/matter/src/utils/parsebuf.rs @@ -35,13 +35,25 @@ impl<'a> ParseBuf<'a> { } } - pub fn set_len(&mut self, left: usize) { - self.left = left; + pub fn reset(&mut self) { + self.read_off = 0; + self.left = self.buf.len(); } - // Return the data that is valid as a slice, consume self - pub fn into_slice(self) -> &'a mut [u8] { - &mut self.buf[self.read_off..(self.read_off + self.left)] + pub fn load(&mut self, pb: &ParseBuf) -> Result<(), Error> { + if self.buf.len() < pb.read_off + pb.left { + Err(ErrorCode::NoSpace)?; + } + + self.buf[0..pb.read_off + pb.left].copy_from_slice(&pb.buf[..pb.read_off + pb.left]); + self.read_off = pb.read_off; + self.left = pb.left; + + Ok(()) + } + + pub fn set_len(&mut self, left: usize) { + self.left = left; } // Return the data that is valid as a slice @@ -114,7 +126,7 @@ mod tests { assert_eq!(buf.le_u8().unwrap(), 0x01); assert_eq!(buf.le_u16().unwrap(), 65); assert_eq!(buf.le_u32().unwrap(), 0xcafebabe); - assert_eq!(buf.into_slice(), [0xa, 0xb, 0xc, 0xd]); + assert_eq!(buf.as_slice(), [0xa, 0xb, 0xc, 0xd]); } #[test] @@ -138,7 +150,7 @@ mod tests { if buf.le_u8().is_ok() { panic!("This should have returned error") } - assert_eq!(buf.into_slice(), []); + assert_eq!(buf.as_slice(), [] as [u8; 0]); } #[test] @@ -154,7 +166,7 @@ mod tests { assert_eq!(buf.as_mut_slice(), [0xa, 0xb]); assert_eq!(buf.tail(2).unwrap(), [0xa, 0xb]); - assert_eq!(buf.into_slice(), []); + assert_eq!(buf.as_slice(), [] as [u8; 0]); } #[test] @@ -176,7 +188,7 @@ mod tests { let mut test_slice = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; let mut buf = ParseBuf::new(&mut test_slice); - assert_eq!(buf.parsed_as_slice(), []); + assert_eq!(buf.parsed_as_slice(), [] as [u8; 0]); assert_eq!(buf.le_u8().unwrap(), 0x1); assert_eq!(buf.le_u16().unwrap(), 65); assert_eq!(buf.le_u32().unwrap(), 0xcafebabe); diff --git a/matter/src/utils/select.rs b/matter/src/utils/select.rs new file mode 100644 index 00000000..2b5d21e9 --- /dev/null +++ b/matter/src/utils/select.rs @@ -0,0 +1,35 @@ +use embassy_futures::select::{Either, Either3, Either4}; + +pub trait EitherUnwrap { + fn unwrap(self) -> T; +} + +impl EitherUnwrap for Either { + fn unwrap(self) -> T { + match self { + Self::First(t) => t, + Self::Second(t) => t, + } + } +} + +impl EitherUnwrap for Either3 { + fn unwrap(self) -> T { + match self { + Self::First(t) => t, + Self::Second(t) => t, + Self::Third(t) => t, + } + } +} + +impl EitherUnwrap for Either4 { + fn unwrap(self) -> T { + match self { + Self::First(t) => t, + Self::Second(t) => t, + Self::Third(t) => t, + Self::Fourth(t) => t, + } + } +} diff --git a/matter/src/utils/writebuf.rs b/matter/src/utils/writebuf.rs index 21a51e28..2f24c977 100644 --- a/matter/src/utils/writebuf.rs +++ b/matter/src/utils/writebuf.rs @@ -68,6 +68,18 @@ impl<'a> WriteBuf<'a> { self.end = 0; } + pub fn load(&mut self, wb: &WriteBuf) -> Result<(), Error> { + if self.buf_size < wb.end { + Err(ErrorCode::NoSpace)?; + } + + self.buf[0..wb.end].copy_from_slice(&wb.buf[..wb.end]); + self.start = wb.start; + self.end = wb.end; + + Ok(()) + } + pub fn reserve(&mut self, reserve: usize) -> Result<(), Error> { if self.end != 0 || self.start != 0 || self.buf_size != self.buf.len() { Err(ErrorCode::Invalid.into()) From 357eb73c6fb3256039716868ce1a9e3fbec015c5 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sun, 28 May 2023 11:04:46 +0000 Subject: [PATCH 45/72] Control memory by removing implicit copy --- matter/src/acl.rs | 20 +- matter/src/data_model/objects/encoder.rs | 10 +- matter/src/data_model/objects/node.rs | 18 +- .../data_model/sdm/general_commissioning.rs | 7 +- .../data_model/system_model/access_control.rs | 10 +- matter/src/interaction_model/messages.rs | 38 ++-- matter/src/mdns.rs | 15 +- matter/src/tlv/parser.rs | 155 +++++---------- matter/src/tlv/traits.rs | 6 +- matter/src/transport/mrp.rs | 6 +- matter/src/transport/session.rs | 14 +- matter/tests/common/attributes.rs | 4 +- matter/tests/common/commands.rs | 12 +- matter/tests/data_model/acl_and_dataver.rs | 4 +- matter/tests/data_model/attribute_lists.rs | 22 ++- matter/tests/data_model/attributes.rs | 4 +- matter/tests/data_model/commands.rs | 6 +- matter/tests/data_model/long_reads.rs | 183 ++++++++++++------ matter_macro_derive/src/lib.rs | 2 +- 19 files changed, 281 insertions(+), 255 deletions(-) diff --git a/matter/src/acl.rs b/matter/src/acl.rs index 8bd8b701..15e4def5 100644 --- a/matter/src/acl.rs +++ b/matter/src/acl.rs @@ -260,7 +260,7 @@ impl<'a> AccessReq<'a> { } } -#[derive(FromTLV, ToTLV, Copy, Clone, Debug, PartialEq)] +#[derive(FromTLV, ToTLV, Clone, Debug, PartialEq)] pub struct Target { cluster: Option, endpoint: Option, @@ -283,7 +283,7 @@ impl Target { type Subjects = [Option; SUBJECTS_PER_ENTRY]; type Targets = [Option; TARGETS_PER_ENTRY]; -#[derive(ToTLV, FromTLV, Copy, Clone, Debug, PartialEq)] +#[derive(ToTLV, FromTLV, Clone, Debug, PartialEq)] #[tlvargs(start = 1)] pub struct AclEntry { privilege: Privilege, @@ -463,7 +463,11 @@ impl AclMgr { pub fn delete_for_fabric(&mut self, fab_idx: u8) -> Result<(), Error> { for entry in &mut self.entries { - if entry.map(|e| e.fab_idx == Some(fab_idx)).unwrap_or(false) { + if entry + .as_ref() + .map(|e| e.fab_idx == Some(fab_idx)) + .unwrap_or(false) + { *entry = None; self.changed = true; } @@ -545,7 +549,11 @@ impl AclMgr { for (curr_index, entry) in self .entries .iter_mut() - .filter(|e| e.filter(|e1| e1.fab_idx == Some(fab_idx)).is_some()) + .filter(|e| { + e.as_ref() + .filter(|e1| e1.fab_idx == Some(fab_idx)) + .is_some() + }) .enumerate() { if curr_index == index as usize { @@ -779,7 +787,7 @@ mod tests { am.borrow_mut().add(new).unwrap(); // Write on an RWVA without admin access - deny - let mut req = AccessReq::new(&accessor, path, Access::WRITE); + let mut req = AccessReq::new(&accessor, path.clone(), Access::WRITE); req.set_target_perms(Access::RWVA); assert_eq!(req.allow(), false); @@ -806,7 +814,7 @@ mod tests { am.borrow_mut().erase_all().unwrap(); let path = GenericPath::new(Some(1), Some(1234), None); let accessor2 = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am); - let mut req2 = AccessReq::new(&accessor2, path, Access::READ); + let mut req2 = AccessReq::new(&accessor2, path.clone(), Access::READ); req2.set_target_perms(Access::RWVA); let accessor3 = Accessor::new(3, AccessorSubjects::new(112233), AuthMode::Case, &am); let mut req3 = AccessReq::new(&accessor3, path, Access::READ); diff --git a/matter/src/data_model/objects/encoder.rs b/matter/src/data_model/objects/encoder.rs index e97eea0e..70e0db76 100644 --- a/matter/src/data_model/objects/encoder.rs +++ b/matter/src/data_model/objects/encoder.rs @@ -39,7 +39,7 @@ use super::{AttrDetails, CmdDetails, Handler}; // the tw.rewind() in that case, if we add this support pub type EncodeValueGen<'a> = &'a dyn Fn(TagType, &mut TLVWriter); -#[derive(Copy, Clone)] +#[derive(Clone)] /// A structure for encoding various types of values pub enum EncodeValue<'a> { /// This indicates a value that is dynamically generated. This variant @@ -66,13 +66,13 @@ impl<'a> EncodeValue<'a> { impl<'a> PartialEq for EncodeValue<'a> { fn eq(&self, other: &Self) -> bool { - match *self { + match self { EncodeValue::Closure(_) => { error!("PartialEq not yet supported"); false } EncodeValue::Tlv(a) => { - if let EncodeValue::Tlv(b) = *other { + if let EncodeValue::Tlv(b) = other { a == b } else { false @@ -89,7 +89,7 @@ impl<'a> PartialEq for EncodeValue<'a> { impl<'a> Debug for EncodeValue<'a> { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), core::fmt::Error> { - match *self { + match self { EncodeValue::Closure(_) => write!(f, "Contains closure"), EncodeValue::Tlv(t) => write!(f, "{:?}", t), EncodeValue::Value(_) => write!(f, "Contains EncodeValue"), @@ -113,7 +113,7 @@ impl<'a> ToTLV for EncodeValue<'a> { impl<'a> FromTLV<'a> for EncodeValue<'a> { fn from_tlv(data: &TLVElement<'a>) -> Result { - Ok(EncodeValue::Tlv(*data)) + Ok(EncodeValue::Tlv(data.clone())) } } diff --git a/matter/src/data_model/objects/node.rs b/matter/src/data_model/objects/node.rs index 3ee3af27..41720b61 100644 --- a/matter/src/data_model/objects/node.rs +++ b/matter/src/data_model/objects/node.rs @@ -84,11 +84,11 @@ impl<'a> Iterable for Option<&'a TLVArray<'a, DataVersionFilter>> { impl<'a> Iterable for &'a [DataVersionFilter] { type Item = DataVersionFilter; - type Iterator<'i> = core::iter::Copied> where Self: 'i; + type Iterator<'i> = core::iter::Cloned> where Self: 'i; fn iter(&self) -> Self::Iterator<'_> { let slice: &[DataVersionFilter] = self; - slice.iter().copied() + slice.iter().cloned() } } @@ -127,11 +127,11 @@ impl<'a> Node<'a> { 's: 'm, { self.read_attr_requests( - req.paths.iter().copied(), + req.paths.iter().cloned(), req.filters.as_slice(), req.fabric_filtered, accessor, - Some(req.resume_path), + Some(req.resume_path.clone()), ) } @@ -163,11 +163,11 @@ impl<'a> Node<'a> { 's: 'm, { self.read_attr_requests( - req.paths.iter().copied(), + req.paths.iter().cloned(), req.filters.as_slice(), req.fabric_filtered, accessor, - Some(req.resume_path.unwrap()), + Some(req.resume_path.clone().unwrap()), ) } @@ -187,7 +187,7 @@ impl<'a> Node<'a> { attr_requests.flat_map(move |path| { if path.to_gp().is_wildcard() { let dataver_filters = dataver_filters.clone(); - let from = from; + let from = from.clone(); let iter = self .match_attributes(path.endpoint, path.cluster, path.attr) @@ -302,7 +302,7 @@ impl<'a> Node<'a> { dataver: attr_data.data_ver, wildcard: true, }, - attr_data.data.unwrap_tlv().unwrap(), + attr_data.data.clone().unwrap_tlv().unwrap(), )) }); @@ -367,7 +367,7 @@ impl<'a> Node<'a> { cmd_id: cmd, wildcard: true, }, - cmd_data.data.unwrap_tlv().unwrap(), + cmd_data.data.clone().unwrap_tlv().unwrap(), )) }); diff --git a/matter/src/data_model/sdm/general_commissioning.rs b/matter/src/data_model/sdm/general_commissioning.rs index b0cdff11..78c3bef3 100644 --- a/matter/src/data_model/sdm/general_commissioning.rs +++ b/matter/src/data_model/sdm/general_commissioning.rs @@ -205,7 +205,10 @@ impl<'a> GenCommCluster<'a> { let status = if self .failsafe .borrow_mut() - .arm(p.expiry_len, transaction.session().get_session_mode()) + .arm( + p.expiry_len, + transaction.session().get_session_mode().clone(), + ) .is_err() { CommissioningError::ErrBusyWithOtherAdmin as u8 @@ -271,7 +274,7 @@ impl<'a> GenCommCluster<'a> { if self .failsafe .borrow_mut() - .disarm(transaction.session().get_session_mode()) + .disarm(transaction.session().get_session_mode().clone()) .is_err() { status = CommissioningError::ErrInvalidAuth as u8; diff --git a/matter/src/data_model/system_model/access_control.rs b/matter/src/data_model/system_model/access_control.rs index c57c0df2..17c88e33 100644 --- a/matter/src/data_model/system_model/access_control.rs +++ b/matter/src/data_model/system_model/access_control.rs @@ -255,8 +255,8 @@ mod tests { AclEntry::new(1, Privilege::VIEW, AuthMode::Case), AclEntry::new(2, Privilege::ADMIN, AuthMode::Case), ]; - for i in verifier { - acl_mgr.borrow_mut().add(i).unwrap(); + for i in &verifier { + acl_mgr.borrow_mut().add(i.clone()).unwrap(); } let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand); @@ -292,8 +292,8 @@ mod tests { AclEntry::new(1, Privilege::VIEW, AuthMode::Case), AclEntry::new(2, Privilege::ADMIN, AuthMode::Case), ]; - for i in input { - acl_mgr.borrow_mut().add(i).unwrap(); + for i in &input { + acl_mgr.borrow_mut().add(i.clone()).unwrap(); } let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand); // data is don't-care actually @@ -303,7 +303,7 @@ mod tests { let result = acl.write_acl_attr(&ListOperation::DeleteItem(0), &data, 1); assert!(result.is_ok()); - let verifier = [input[0], input[2]]; + let verifier = [input[0].clone(), input[2].clone()]; // Also validate in the acl_mgr that the entries are in the right order let mut index = 0; acl_mgr diff --git a/matter/src/interaction_model/messages.rs b/matter/src/interaction_model/messages.rs index bfd8a8b4..0cf859a9 100644 --- a/matter/src/interaction_model/messages.rs +++ b/matter/src/interaction_model/messages.rs @@ -23,7 +23,7 @@ use crate::{ // A generic path with endpoint, clusters, and a leaf // The leaf could be command, attribute, event -#[derive(Default, Clone, Copy, Debug, PartialEq, FromTLV, ToTLV)] +#[derive(Default, Clone, Debug, PartialEq, FromTLV, ToTLV)] #[tlvargs(datatype = "list")] pub struct GenericPath { pub endpoint: Option, @@ -106,16 +106,6 @@ pub mod msg { self.attr_requests = Some(TLVArray::new(requests)); self } - - pub fn to_read_req(&self) -> ReadReq<'a> { - ReadReq { - attr_requests: self.attr_requests, - event_requests: self.event_requests, - event_filters: self.event_filters, - fabric_filtered: self.fabric_filtered, - dataver_filters: self.dataver_filters, - } - } } #[derive(Debug, FromTLV, ToTLV)] @@ -268,7 +258,7 @@ pub mod ib { use super::GenericPath; // Command Response - #[derive(Clone, Copy, FromTLV, ToTLV, Debug)] + #[derive(Clone, FromTLV, ToTLV, Debug)] #[tlvargs(lifetime = "'a")] pub enum InvResp<'a> { Cmd(CmdData<'a>), @@ -301,7 +291,7 @@ pub mod ib { } } - #[derive(FromTLV, ToTLV, Copy, Clone, PartialEq, Debug)] + #[derive(FromTLV, ToTLV, Clone, PartialEq, Debug)] pub struct CmdStatus { path: CmdPath, status: Status, @@ -319,7 +309,7 @@ pub mod ib { } } - #[derive(Debug, Clone, Copy, FromTLV, ToTLV)] + #[derive(Debug, Clone, FromTLV, ToTLV)] #[tlvargs(lifetime = "'a")] pub struct CmdData<'a> { pub path: CmdPath, @@ -338,7 +328,7 @@ pub mod ib { } // Status - #[derive(Debug, Clone, Copy, PartialEq, FromTLV, ToTLV)] + #[derive(Debug, Clone, PartialEq, FromTLV, ToTLV)] pub struct Status { pub status: IMStatusCode, pub cluster_status: u16, @@ -354,7 +344,7 @@ pub mod ib { } // Attribute Response - #[derive(Clone, Copy, FromTLV, ToTLV, PartialEq, Debug)] + #[derive(Clone, FromTLV, ToTLV, PartialEq, Debug)] #[tlvargs(lifetime = "'a")] pub enum AttrResp<'a> { Status(AttrStatus), @@ -390,7 +380,7 @@ pub mod ib { } // Attribute Data - #[derive(Clone, Copy, PartialEq, FromTLV, ToTLV, Debug)] + #[derive(Clone, PartialEq, FromTLV, ToTLV, Debug)] #[tlvargs(lifetime = "'a")] pub struct AttrData<'a> { pub data_ver: Option, @@ -458,7 +448,7 @@ pub mod ib { } } - #[derive(Debug, Clone, Copy, PartialEq, FromTLV, ToTLV)] + #[derive(Debug, Clone, PartialEq, FromTLV, ToTLV)] pub struct AttrStatus { path: AttrPath, status: Status, @@ -474,7 +464,7 @@ pub mod ib { } // Attribute Path - #[derive(Default, Clone, Copy, Debug, PartialEq, FromTLV, ToTLV)] + #[derive(Default, Clone, Debug, PartialEq, FromTLV, ToTLV)] #[tlvargs(datatype = "list")] pub struct AttrPath { pub tag_compression: Option, @@ -501,7 +491,7 @@ pub mod ib { } // Command Path - #[derive(Default, Debug, Copy, Clone, PartialEq)] + #[derive(Default, Debug, Clone, PartialEq)] pub struct CmdPath { pub path: GenericPath, } @@ -557,20 +547,20 @@ pub mod ib { } } - #[derive(FromTLV, ToTLV, Copy, Clone, Debug)] + #[derive(FromTLV, ToTLV, Clone, Debug)] pub struct ClusterPath { pub node: Option, pub endpoint: EndptId, pub cluster: ClusterId, } - #[derive(FromTLV, ToTLV, Copy, Clone, Debug)] + #[derive(FromTLV, ToTLV, Clone, Debug)] pub struct DataVersionFilter { pub path: ClusterPath, pub data_ver: u32, } - #[derive(FromTLV, ToTLV, Copy, Clone, Debug)] + #[derive(FromTLV, ToTLV, Clone, Debug)] #[tlvargs(datatype = "list")] pub struct EventPath { pub node: Option, @@ -580,7 +570,7 @@ pub mod ib { pub is_urgent: Option, } - #[derive(FromTLV, ToTLV, Copy, Clone, Debug)] + #[derive(FromTLV, ToTLV, Clone, Debug)] pub struct EventFilter { pub node: Option, pub event_min: Option, diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index a1876833..80333b14 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -234,15 +234,15 @@ pub mod builtin { use crate::transport::udp::UdpListener; use crate::utils::select::EitherUnwrap; - const IP_BROADCAST_ADDRS: [SocketAddr; 2] = [ - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(224, 0, 0, 251)), 5353), - SocketAddr::new( + const IP_BROADCAST_ADDRS: [(IpAddr, u16); 2] = [ + (IpAddr::V4(Ipv4Addr::new(224, 0, 0, 251)), 5353), + ( IpAddr::V6(Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x00fb)), 5353, ), ]; - const IP_BIND_ADDR: SocketAddr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 5353); + const IP_BIND_ADDR: (IpAddr, u16) = (IpAddr::V6(Ipv6Addr::UNSPECIFIED), 5353); pub fn create_record( id: u16, @@ -429,7 +429,8 @@ pub mod builtin { async fn bind(&self) -> Result<(), Error> { if self.udp.borrow().is_none() { - *self.udp.borrow_mut() = Some(UdpListener::new(IP_BIND_ADDR).await?); + *self.udp.borrow_mut() = + Some(UdpListener::new(SocketAddr::new(IP_BIND_ADDR.0, IP_BIND_ADDR.1)).await?); } Ok(()) @@ -575,8 +576,8 @@ pub mod builtin { let udp = self.0.udp.borrow(); let udp = udp.as_ref().unwrap(); - for addr in IP_BROADCAST_ADDRS { - udp.send(addr, &entry.record).await?; + for (addr, port) in IP_BROADCAST_ADDRS { + udp.send(SocketAddr::new(addr, port), &entry.record).await?; } index += 1; diff --git a/matter/src/tlv/parser.rs b/matter/src/tlv/parser.rs index 0c179e2a..8bfdd28a 100644 --- a/matter/src/tlv/parser.rs +++ b/matter/src/tlv/parser.rs @@ -33,14 +33,7 @@ impl<'a> TLVList<'a> { } } -#[derive(Debug, Copy, Clone, PartialEq)] -pub struct Pointer<'a> { - buf: &'a [u8], - current: usize, - left: usize, -} - -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub enum ElementType<'a> { S8(i8), S16(i16), @@ -63,9 +56,9 @@ pub enum ElementType<'a> { Str32l, Str64l, Null, - Struct(Pointer<'a>), - Array(Pointer<'a>), - List(Pointer<'a>), + Struct(&'a [u8]), + Array(&'a [u8]), + List(&'a [u8]), EndCnt, Last, } @@ -204,44 +197,11 @@ static VALUE_EXTRACTOR: [ExtractValue; MAX_VALUE_INDEX] = [ // Null 20 { |_t| (0, ElementType::Null) }, // Struct 21 - { - |t| { - ( - 0, - ElementType::Struct(Pointer { - buf: t.buf, - current: t.current, - left: t.left, - }), - ) - } - }, + { |t| (0, ElementType::Struct(&t.buf[t.current..])) }, // Array 22 - { - |t| { - ( - 0, - ElementType::Array(Pointer { - buf: t.buf, - current: t.current, - left: t.left, - }), - ) - } - }, + { |t| (0, ElementType::Array(&t.buf[t.current..])) }, // List 23 - { - |t| { - ( - 0, - ElementType::List(Pointer { - buf: t.buf, - current: t.current, - left: t.left, - }), - ) - } - }, + { |t| (0, ElementType::List(&t.buf[t.current..])) }, // EndCnt 24 { |_t| (0, ElementType::EndCnt) }, ]; @@ -282,7 +242,7 @@ fn read_length_value<'a>( // The current offset is the string size let length: usize = LittleEndian::read_uint(&t.buf[t.current..], size_of_length_field) as usize; // We'll consume the current offset (len) + the entire string - if length + size_of_length_field > t.left { + if length + size_of_length_field > t.buf.len() - t.current { // Return Error Err(ErrorCode::NoSpace.into()) } else { @@ -294,7 +254,7 @@ fn read_length_value<'a>( } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct TLVElement<'a> { tag_type: TagType, element_type: ElementType<'a>, @@ -303,11 +263,11 @@ pub struct TLVElement<'a> { impl<'a> PartialEq for TLVElement<'a> { fn eq(&self, other: &Self) -> bool { match self.element_type { - ElementType::Struct(a) | ElementType::Array(a) | ElementType::List(a) => { - let mut our_iter = TLVListIterator::from_pointer(a); + ElementType::Struct(buf) | ElementType::Array(buf) | ElementType::List(buf) => { + let mut our_iter = TLVListIterator::from_buf(buf); let mut their = match other.element_type { - ElementType::Struct(b) | ElementType::Array(b) | ElementType::List(b) => { - TLVListIterator::from_pointer(b) + ElementType::Struct(buf) | ElementType::Array(buf) | ElementType::List(buf) => { + TLVListIterator::from_buf(buf) } _ => { // If we are a container, the other must be a container, else this is a mismatch @@ -336,7 +296,7 @@ impl<'a> PartialEq for TLVElement<'a> { } nest_level -= 1; } else { - if is_container(ours.element_type) { + if is_container(&ours.element_type) { nest_level += 1; // Only compare the discriminants in case of array/list/structures, // instead of actual element values. Those will be subsets within this same @@ -364,15 +324,11 @@ impl<'a> PartialEq for TLVElement<'a> { impl<'a> TLVElement<'a> { pub fn enter(&self) -> Option> { - let ptr = match self.element_type { - ElementType::Struct(a) | ElementType::Array(a) | ElementType::List(a) => a, + let buf = match self.element_type { + ElementType::Struct(buf) | ElementType::Array(buf) | ElementType::List(buf) => buf, _ => return None, }; - let list_iter = TLVListIterator { - buf: ptr.buf, - current: ptr.current, - left: ptr.left, - }; + let list_iter = TLVListIterator { buf, current: 0 }; Some(TLVContainerIterator { list_iter, prev_container: false, @@ -465,23 +421,23 @@ impl<'a> TLVElement<'a> { } } - pub fn confirm_struct(&self) -> Result, Error> { + pub fn confirm_struct(&self) -> Result<&TLVElement<'a>, Error> { match self.element_type { - ElementType::Struct(_) => Ok(*self), + ElementType::Struct(_) => Ok(self), _ => Err(ErrorCode::TLVTypeMismatch.into()), } } - pub fn confirm_array(&self) -> Result, Error> { + pub fn confirm_array(&self) -> Result<&TLVElement<'a>, Error> { match self.element_type { - ElementType::Array(_) => Ok(*self), + ElementType::Array(_) => Ok(self), _ => Err(ErrorCode::TLVTypeMismatch.into()), } } - pub fn confirm_list(&self) -> Result, Error> { + pub fn confirm_list(&self) -> Result<&TLVElement<'a>, Error> { match self.element_type { - ElementType::List(_) => Ok(*self), + ElementType::List(_) => Ok(self), _ => Err(ErrorCode::TLVTypeMismatch.into()), } } @@ -511,8 +467,8 @@ impl<'a> TLVElement<'a> { false } - pub fn get_element_type(&self) -> ElementType { - self.element_type + pub fn get_element_type(&self) -> &ElementType { + &self.element_type } } @@ -546,25 +502,19 @@ impl<'a> fmt::Display for TLVElement<'a> { } // This is a TLV List iterator, it only iterates over the individual TLVs in a TLV list -#[derive(Copy, Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq)] pub struct TLVListIterator<'a> { buf: &'a [u8], current: usize, - left: usize, } impl<'a> TLVListIterator<'a> { - fn from_pointer(p: Pointer<'a>) -> Self { - Self { - buf: p.buf, - current: p.current, - left: p.left, - } + fn from_buf(buf: &'a [u8]) -> Self { + Self { buf, current: 0 } } fn advance(&mut self, len: usize) { self.current += len; - self.left -= len; } // Caller should ensure they are reading the _right_ tag at the _right_ place @@ -573,7 +523,7 @@ impl<'a> TLVListIterator<'a> { return None; } let tag_size = TAG_SIZE_MAP[tag_type as usize]; - if tag_size > self.left { + if tag_size > self.buf.len() - self.current { return None; } let tag = (TAG_EXTRACTOR[tag_type as usize])(self); @@ -586,7 +536,7 @@ impl<'a> TLVListIterator<'a> { return None; } let mut size = VALUE_SIZE_MAP[element_type as usize]; - if size > self.left { + if size > self.buf.len() - self.current { error!( "Invalid value found: {} self {:?} size {}", element_type, self, size @@ -609,7 +559,7 @@ impl<'a> Iterator for TLVListIterator<'a> { type Item = TLVElement<'a>; /* Code for going to the next Element */ fn next(&mut self) -> Option> { - if self.left < 1 { + if self.buf.len() - self.current < 1 { return None; } /* Read Control */ @@ -635,13 +585,12 @@ impl<'a> TLVList<'a> { pub fn iter(&self) -> TLVListIterator<'a> { TLVListIterator { current: 0, - left: self.buf.len(), buf: self.buf, } } } -fn is_container(element_type: ElementType) -> bool { +fn is_container(element_type: &ElementType) -> bool { matches!( element_type, ElementType::Struct(_) | ElementType::Array(_) | ElementType::List(_) @@ -680,7 +629,7 @@ impl<'a> TLVContainerIterator<'a> { nest_level -= 1; } _ => { - if is_container(element.element_type) { + if is_container(&element.element_type) { nest_level += 1; } } @@ -711,7 +660,7 @@ impl<'a> Iterator for TLVContainerIterator<'a> { return None; } - self.prev_container = is_container(element.element_type); + self.prev_container = is_container(&element.element_type); Some(element) } } @@ -724,19 +673,25 @@ pub fn get_root_node(b: &[u8]) -> Result { } pub fn get_root_node_struct(b: &[u8]) -> Result { - TLVList::new(b) + let root = TLVList::new(b) .iter() .next() - .ok_or(ErrorCode::InvalidData)? - .confirm_struct() + .ok_or(ErrorCode::InvalidData)?; + + root.confirm_struct()?; + + Ok(root) } pub fn get_root_node_list(b: &[u8]) -> Result { - TLVList::new(b) + let root = TLVList::new(b) .iter() .next() - .ok_or(ErrorCode::InvalidData)? - .confirm_list() + .ok_or(ErrorCode::InvalidData)?; + + root.confirm_list()?; + + Ok(root) } pub fn print_tlv_list(b: &[u8]) { @@ -798,8 +753,7 @@ mod tests { use log::info; use super::{ - get_root_node_list, get_root_node_struct, ElementType, Pointer, TLVElement, TLVList, - TagType, + get_root_node_list, get_root_node_struct, ElementType, TLVElement, TLVList, TagType, }; use crate::error::ErrorCode; @@ -859,11 +813,7 @@ mod tests { tlv_iter.next(), Some(TLVElement { tag_type: TagType::Context(0), - element_type: ElementType::Array(Pointer { - buf: &[21, 54, 0], - current: 3, - left: 0 - }), + element_type: ElementType::Array(&[]), }) ); } @@ -1123,7 +1073,8 @@ mod tests { // This is an array of CommandDataIB, but we'll only use the first element let cmd_data_ib = cmd_list_iter.next().unwrap(); - let cmd_path = cmd_data_ib.find_tag(0).unwrap().confirm_list().unwrap(); + let cmd_path = cmd_data_ib.find_tag(0).unwrap(); + let cmd_path = cmd_path.confirm_list().unwrap(); assert_eq!( cmd_path.find_tag(0).unwrap(), TLVElement { @@ -1188,11 +1139,7 @@ mod tests { 0x35, 0x1, 0x18, 0x18, 0x18, 0x18, ]; - let dummy_pointer = Pointer { - buf: &b, - current: 1, - left: 21, - }; + let dummy_pointer = &b[1..]; // These are the decoded elements that we expect from this input let verify_matrix: [(TagType, ElementType); 13] = [ (TagType::Anonymous, ElementType::Struct(dummy_pointer)), diff --git a/matter/src/tlv/traits.rs b/matter/src/tlv/traits.rs index 28c236be..9a8edcdd 100644 --- a/matter/src/tlv/traits.rs +++ b/matter/src/tlv/traits.rs @@ -298,7 +298,7 @@ impl ToTLV for Nullable { } } -#[derive(Copy, Clone)] +#[derive(Clone)] pub enum TLVArray<'a, T> { // This is used for the to-tlv path Slice(&'a [T]), @@ -317,7 +317,7 @@ impl<'a, T: ToTLV> TLVArray<'a, T> { } pub fn iter(&self) -> TLVArrayIter<'a, T> { - match *self { + match self { Self::Slice(s) => TLVArrayIter::Slice(s.iter()), Self::Ptr(p) => TLVArrayIter::Ptr(p.enter()), } @@ -401,7 +401,7 @@ impl<'a, T: FromTLV<'a> + Clone + ToTLV> ToTLV for TLVArray<'a, T> { impl<'a, T> FromTLV<'a> for TLVArray<'a, T> { fn from_tlv(t: &TLVElement<'a>) -> Result { t.confirm_array()?; - Ok(Self::Ptr(*t)) + Ok(Self::Ptr(t.clone())) } } diff --git a/matter/src/transport/mrp.rs b/matter/src/transport/mrp.rs index 2d046bf5..d9815919 100644 --- a/matter/src/transport/mrp.rs +++ b/matter/src/transport/mrp.rs @@ -41,7 +41,7 @@ impl RetransEntry { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct AckEntry { // The msg counter that we should acknowledge msg_ctr: u32, @@ -92,7 +92,7 @@ impl ReliableMessage { // Check any pending acknowledgements / retransmissions and take action pub fn is_ack_ready(&self, epoch: Epoch) -> bool { // Acknowledgements - if let Some(ack_entry) = self.ack { + if let Some(ack_entry) = &self.ack { ack_entry.has_timed_out(epoch) } else { false @@ -107,7 +107,7 @@ impl ReliableMessage { // Check if any acknowledgements are pending for this exchange, // if so, piggy back in the encoded header here - if let Some(ack_entry) = self.ack { + if let Some(ack_entry) = &self.ack { // Ack Entry exists, set ACK bit and remove from table proto_tx.proto.set_ack(ack_entry.get_msg_ctr()); self.ack = None; diff --git a/matter/src/transport/session.rs b/matter/src/transport/session.rs index 1e3a1d4f..1c2e9365 100644 --- a/matter/src/transport/session.rs +++ b/matter/src/transport/session.rs @@ -37,7 +37,7 @@ pub type NocCatIds = [u32; MAX_CAT_IDS_PER_NOC]; const MATTER_AES128_KEY_SIZE: usize = 16; -#[derive(Debug, Default, Copy, Clone, PartialEq)] +#[derive(Debug, Default, Clone, PartialEq)] pub struct CaseDetails { pub fab_idx: u8, pub cat_ids: NocCatIds, @@ -52,7 +52,7 @@ impl CaseDetails { } } -#[derive(Debug, PartialEq, Copy, Clone, Default)] +#[derive(Debug, PartialEq, Clone, Default)] pub enum SessionMode { // The Case session will capture the local fabric index Case(CaseDetails), @@ -149,7 +149,7 @@ impl Session { peer_sess_id: clone_from.peer_sess_id, msg_ctr: Self::rand_msg_ctr(rand), rx_ctr_state: RxCtrState::new(0), - mode: clone_from.mode, + mode: clone_from.mode.clone(), data: None, last_use: epoch(), } @@ -207,14 +207,14 @@ impl Session { } pub fn get_local_fabric_idx(&self) -> Option { - match self.mode { + match &self.mode { SessionMode::Case(a) => Some(a.fab_idx), _ => None, } } - pub fn get_session_mode(&self) -> SessionMode { - self.mode + pub fn get_session_mode(&self) -> &SessionMode { + &self.mode } pub fn get_msg_ctr(&mut self) -> u32 { @@ -454,7 +454,7 @@ impl SessionMgr { Some(self.get_session_handle(index)) } - fn get_or_add( + pub fn get_or_add( &mut self, sess_id: u16, peer_addr: Address, diff --git a/matter/tests/common/attributes.rs b/matter/tests/common/attributes.rs index 2ff95eb6..3a4f5e77 100644 --- a/matter/tests/common/attributes.rs +++ b/matter/tests/common/attributes.rs @@ -28,7 +28,7 @@ pub fn __assert_attr_report(received: &ReportDataMsg, expected: &[AttrResp], ski // We can't use assert_eq because it will also try to match data-version for inv_response in received.attr_reports.as_ref().unwrap().iter() { println!("Validating index {}", index); - match expected[index] { + match &expected[index] { AttrResp::Data(e_d) => match inv_response { AttrResp::Data(d) => { // We don't match the data-version @@ -41,7 +41,7 @@ pub fn __assert_attr_report(received: &ReportDataMsg, expected: &[AttrResp], ski panic!("Invalid response, expected AttrRespIn::Data"); } }, - AttrResp::Status(s) => assert_eq!(AttrResp::Status(s), inv_response), + AttrResp::Status(s) => assert_eq!(AttrResp::Status(s.clone()), inv_response), } println!("Index {} success", index); index += 1; diff --git a/matter/tests/common/commands.rs b/matter/tests/common/commands.rs index 419b6ac4..d1e0402d 100644 --- a/matter/tests/common/commands.rs +++ b/matter/tests/common/commands.rs @@ -30,15 +30,15 @@ pub enum ExpectedInvResp { pub fn assert_inv_response(resp: &msg::InvResp, expected: &[ExpectedInvResp]) { let mut index = 0; - for inv_response in resp.inv_responses.unwrap().iter() { + for inv_response in resp.inv_responses.as_ref().unwrap().iter() { println!("Validating index {}", index); - match expected[index] { + match &expected[index] { ExpectedInvResp::Cmd(e_c, e_d) => match inv_response { InvResp::Cmd(c) => { - assert_eq!(e_c, c.path); + assert_eq!(e_c, &c.path); match c.data { EncodeValue::Tlv(t) => { - assert_eq!(e_d, t.find_tag(0).unwrap().u8().unwrap()) + assert_eq!(*e_d, t.find_tag(0).unwrap().u8().unwrap()) } _ => panic!("Incorrect CmdDataType"), } @@ -49,7 +49,7 @@ pub fn assert_inv_response(resp: &msg::InvResp, expected: &[ExpectedInvResp]) { }, ExpectedInvResp::Status(e_status) => match inv_response { InvResp::Status(status) => { - assert_eq!(e_status, status); + assert_eq!(e_status, &status); } _ => { panic!("Invalid response, expected InvResponse::Status"); @@ -64,7 +64,7 @@ pub fn assert_inv_response(resp: &msg::InvResp, expected: &[ExpectedInvResp]) { #[macro_export] macro_rules! cmd_data { - ($path:ident, $data:literal) => { + ($path:expr, $data:literal) => { CmdData::new($path, EncodeValue::Value(&($data as u32))) }; } diff --git a/matter/tests/data_model/acl_and_dataver.rs b/matter/tests/data_model/acl_and_dataver.rs index 81f220ac..ebb831ca 100644 --- a/matter/tests/data_model/acl_and_dataver.rs +++ b/matter/tests/data_model/acl_and_dataver.rs @@ -429,7 +429,7 @@ fn write_with_runtime_acl_add() { peer, None, // write to echo-cluster attribute, write to acl attribute, write to echo-cluster attribute - &[input0, acl_input, input0], + &[input0.clone(), acl_input, input0], &[ AttrStatus::new(&ep0_att, IMStatusCode::UnsupportedAccess, 0), AttrStatus::new(&acl_att, IMStatusCode::Success, 0), @@ -597,7 +597,7 @@ fn test_write_data_ver() { let input_correct_dataver = &[AttrData::new( Some(initial_data_ver), AttrPath::new(&ep0_attwrite), - attr_data1, + attr_data1.clone(), )]; im.handle_write_reqs( peer, diff --git a/matter/tests/data_model/attribute_lists.rs b/matter/tests/data_model/attribute_lists.rs index aaa2b635..636c9c0a 100644 --- a/matter/tests/data_model/attribute_lists.rs +++ b/matter/tests/data_model/attribute_lists.rs @@ -59,7 +59,11 @@ fn attr_list_ops() { let mut att_path = AttrPath::new(&att_data); // Test 1: Add Operation - add val0 - let input = &[AttrData::new(None, att_path, EncodeValue::Value(&val0))]; + let input = &[AttrData::new( + None, + att_path.clone(), + EncodeValue::Value(&val0), + )]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); @@ -69,7 +73,11 @@ fn attr_list_ops() { } // Test 2: Another Add Operation - add val1 - let input = &[AttrData::new(None, att_path, EncodeValue::Value(&val1))]; + let input = &[AttrData::new( + None, + att_path.clone(), + EncodeValue::Value(&val1), + )]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); @@ -80,7 +88,11 @@ fn attr_list_ops() { // Test 3: Edit Operation - edit val1 to val0 att_path.list_index = Some(Nullable::NotNull(1)); - let input = &[AttrData::new(None, att_path, EncodeValue::Value(&val0))]; + let input = &[AttrData::new( + None, + att_path.clone(), + EncodeValue::Value(&val0), + )]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); @@ -91,7 +103,7 @@ fn attr_list_ops() { // Test 4: Delete Operation - delete index 0 att_path.list_index = Some(Nullable::NotNull(0)); - let input = &[AttrData::new(None, att_path, delete_item)]; + let input = &[AttrData::new(None, att_path.clone(), delete_item)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); @@ -105,7 +117,7 @@ fn attr_list_ops() { att_path.list_index = None; let input = &[AttrData::new( None, - att_path, + att_path.clone(), EncodeValue::Value(&overwrite_val), )]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; diff --git a/matter/tests/data_model/attributes.rs b/matter/tests/data_model/attributes.rs index 6d1072cd..7d185268 100644 --- a/matter/tests/data_model/attributes.rs +++ b/matter/tests/data_model/attributes.rs @@ -218,7 +218,7 @@ fn test_read_wc_endpoint_wc_attribute() { Some(echo_cluster::ID), Some(GlobalElements::AttributeList as u32), ), - attr_list_tlv.get_element_type() + attr_list_tlv.get_element_type().clone() ), attr_data_path!( GenericPath::new( @@ -258,7 +258,7 @@ fn test_read_wc_endpoint_wc_attribute() { Some(echo_cluster::ID), Some(GlobalElements::AttributeList as u32), ), - attr_list_tlv.get_element_type() + attr_list_tlv.get_element_type().clone() ), attr_data_path!( GenericPath::new( diff --git a/matter/tests/data_model/commands.rs b/matter/tests/data_model/commands.rs index a232f269..0d9c0c3f 100644 --- a/matter/tests/data_model/commands.rs +++ b/matter/tests/data_model/commands.rs @@ -75,10 +75,10 @@ fn test_invoke_cmds_unsupported_fields() { let invalid_command = CmdPath::new(Some(0), Some(echo_cluster::ID), Some(0x1234)); let invalid_command_wc_endpoint = CmdPath::new(None, Some(echo_cluster::ID), Some(0x1234)); let input = &[ - cmd_data!(invalid_endpoint, 5), - cmd_data!(invalid_cluster, 5), + cmd_data!(invalid_endpoint.clone(), 5), + cmd_data!(invalid_cluster.clone(), 5), cmd_data!(invalid_cluster_wc_endpoint, 5), - cmd_data!(invalid_command, 5), + cmd_data!(invalid_command.clone(), 5), cmd_data!(invalid_command_wc_endpoint, 5), ]; diff --git a/matter/tests/data_model/long_reads.rs b/matter/tests/data_model/long_reads.rs index a396cc0d..21c25595 100644 --- a/matter/tests/data_model/long_reads.rs +++ b/matter/tests/data_model/long_reads.rs @@ -77,163 +77,228 @@ fn wildcard_read_resp(part: u8) -> Vec> { // For brevity, we only check the AttrPath, not the actual 'data' let dont_care = ElementType::U8(0); let part1 = vec![ - attr_data!(0, 29, GlobalElements::FeatureMap, dont_care), - attr_data!(0, 29, GlobalElements::AttributeList, dont_care), - attr_data!(0, 29, descriptor::Attributes::DeviceTypeList, dont_care), - attr_data!(0, 29, descriptor::Attributes::ServerList, dont_care), - attr_data!(0, 29, descriptor::Attributes::PartsList, dont_care), - attr_data!(0, 29, descriptor::Attributes::ClientList, dont_care), - attr_data!(0, 40, GlobalElements::FeatureMap, dont_care), - attr_data!(0, 40, GlobalElements::AttributeList, dont_care), + attr_data!(0, 29, GlobalElements::FeatureMap, dont_care.clone()), + attr_data!(0, 29, GlobalElements::AttributeList, dont_care.clone()), + attr_data!( + 0, + 29, + descriptor::Attributes::DeviceTypeList, + dont_care.clone() + ), + attr_data!(0, 29, descriptor::Attributes::ServerList, dont_care.clone()), + attr_data!(0, 29, descriptor::Attributes::PartsList, dont_care.clone()), + attr_data!(0, 29, descriptor::Attributes::ClientList, dont_care.clone()), + attr_data!(0, 40, GlobalElements::FeatureMap, dont_care.clone()), + attr_data!(0, 40, GlobalElements::AttributeList, dont_care.clone()), attr_data!( 0, 40, basic_info::AttributesDiscriminants::DMRevision, - dont_care + dont_care.clone() ), attr_data!( 0, 40, basic_info::AttributesDiscriminants::VendorId, - dont_care + dont_care.clone() ), attr_data!( 0, 40, basic_info::AttributesDiscriminants::ProductId, - dont_care + dont_care.clone() + ), + attr_data!( + 0, + 40, + basic_info::AttributesDiscriminants::HwVer, + dont_care.clone() + ), + attr_data!( + 0, + 40, + basic_info::AttributesDiscriminants::SwVer, + dont_care.clone() ), - attr_data!(0, 40, basic_info::AttributesDiscriminants::HwVer, dont_care), - attr_data!(0, 40, basic_info::AttributesDiscriminants::SwVer, dont_care), attr_data!( 0, 40, basic_info::AttributesDiscriminants::SwVerString, - dont_care + dont_care.clone() ), attr_data!( 0, 40, basic_info::AttributesDiscriminants::SerialNo, - dont_care + dont_care.clone() ), - attr_data!(0, 48, GlobalElements::FeatureMap, dont_care), - attr_data!(0, 48, GlobalElements::AttributeList, dont_care), + attr_data!(0, 48, GlobalElements::FeatureMap, dont_care.clone()), + attr_data!(0, 48, GlobalElements::AttributeList, dont_care.clone()), attr_data!( 0, 48, gen_comm::AttributesDiscriminants::BreadCrumb, - dont_care + dont_care.clone() ), attr_data!( 0, 48, gen_comm::AttributesDiscriminants::RegConfig, - dont_care + dont_care.clone() ), attr_data!( 0, 48, gen_comm::AttributesDiscriminants::LocationCapability, - dont_care + dont_care.clone() ), attr_data!( 0, 48, gen_comm::AttributesDiscriminants::BasicCommissioningInfo, - dont_care + dont_care.clone() ), - attr_data!(0, 49, GlobalElements::FeatureMap, dont_care), - attr_data!(0, 49, GlobalElements::AttributeList, dont_care), - attr_data!(0, 60, GlobalElements::FeatureMap, dont_care), - attr_data!(0, 60, GlobalElements::AttributeList, dont_care), + attr_data!(0, 49, GlobalElements::FeatureMap, dont_care.clone()), + attr_data!(0, 49, GlobalElements::AttributeList, dont_care.clone()), + attr_data!(0, 60, GlobalElements::FeatureMap, dont_care.clone()), + attr_data!(0, 60, GlobalElements::AttributeList, dont_care.clone()), attr_data!( 0, 60, adm_comm::AttributesDiscriminants::WindowStatus, - dont_care + dont_care.clone() ), attr_data!( 0, 60, adm_comm::AttributesDiscriminants::AdminFabricIndex, - dont_care + dont_care.clone() ), attr_data!( 0, 60, adm_comm::AttributesDiscriminants::AdminVendorId, - dont_care + dont_care.clone() ), - attr_data!(0, 62, GlobalElements::FeatureMap, dont_care), - attr_data!(0, 62, GlobalElements::AttributeList, dont_care), + attr_data!(0, 62, GlobalElements::FeatureMap, dont_care.clone()), + attr_data!(0, 62, GlobalElements::AttributeList, dont_care.clone()), attr_data!( 0, 62, noc::AttributesDiscriminants::CurrentFabricIndex, - dont_care + dont_care.clone() + ), + attr_data!( + 0, + 62, + noc::AttributesDiscriminants::Fabrics, + dont_care.clone() ), - attr_data!(0, 62, noc::AttributesDiscriminants::Fabrics, dont_care), attr_data!( 0, 62, noc::AttributesDiscriminants::SupportedFabrics, - dont_care + dont_care.clone() ), attr_data!( 0, 62, noc::AttributesDiscriminants::CommissionedFabrics, - dont_care + dont_care.clone() + ), + attr_data!(0, 31, GlobalElements::FeatureMap, dont_care.clone()), + attr_data!(0, 31, GlobalElements::AttributeList, dont_care.clone()), + attr_data!(0, 31, acl::AttributesDiscriminants::Acl, dont_care.clone()), + attr_data!( + 0, + 31, + acl::AttributesDiscriminants::Extension, + dont_care.clone() ), - attr_data!(0, 31, GlobalElements::FeatureMap, dont_care), - attr_data!(0, 31, GlobalElements::AttributeList, dont_care), - attr_data!(0, 31, acl::AttributesDiscriminants::Acl, dont_care), - attr_data!(0, 31, acl::AttributesDiscriminants::Extension, dont_care), attr_data!( 0, 31, acl::AttributesDiscriminants::SubjectsPerEntry, - dont_care + dont_care.clone() ), attr_data!( 0, 31, acl::AttributesDiscriminants::TargetsPerEntry, - dont_care + dont_care.clone() ), attr_data!( 0, 31, acl::AttributesDiscriminants::EntriesPerFabric, - dont_care + dont_care.clone() + ), + attr_data!(0, echo::ID, GlobalElements::FeatureMap, dont_care.clone()), + attr_data!( + 0, + echo::ID, + GlobalElements::AttributeList, + dont_care.clone() + ), + attr_data!( + 0, + echo::ID, + echo::AttributesDiscriminants::Att1, + dont_care.clone() + ), + attr_data!( + 0, + echo::ID, + echo::AttributesDiscriminants::Att2, + dont_care.clone() ), - attr_data!(0, echo::ID, GlobalElements::FeatureMap, dont_care), - attr_data!(0, echo::ID, GlobalElements::AttributeList, dont_care), - attr_data!(0, echo::ID, echo::AttributesDiscriminants::Att1, dont_care), - attr_data!(0, echo::ID, echo::AttributesDiscriminants::Att2, dont_care), attr_data!( 0, echo::ID, echo::AttributesDiscriminants::AttCustom, - dont_care + dont_care.clone() + ), + attr_data!(1, 29, GlobalElements::FeatureMap, dont_care.clone()), + attr_data!(1, 29, GlobalElements::AttributeList, dont_care.clone()), + attr_data!( + 1, + 29, + descriptor::Attributes::DeviceTypeList, + dont_care.clone() ), - attr_data!(1, 29, GlobalElements::FeatureMap, dont_care), - attr_data!(1, 29, GlobalElements::AttributeList, dont_care), - attr_data!(1, 29, descriptor::Attributes::DeviceTypeList, dont_care), ]; let part2 = vec![ - attr_data!(1, 29, descriptor::Attributes::ServerList, dont_care), - attr_data!(1, 29, descriptor::Attributes::PartsList, dont_care), - attr_data!(1, 29, descriptor::Attributes::ClientList, dont_care), - attr_data!(1, 6, GlobalElements::FeatureMap, dont_care), - attr_data!(1, 6, GlobalElements::AttributeList, dont_care), - attr_data!(1, 6, onoff::AttributesDiscriminants::OnOff, dont_care), - attr_data!(1, echo::ID, GlobalElements::FeatureMap, dont_care), - attr_data!(1, echo::ID, GlobalElements::AttributeList, dont_care), - attr_data!(1, echo::ID, echo::AttributesDiscriminants::Att1, dont_care), - attr_data!(1, echo::ID, echo::AttributesDiscriminants::Att2, dont_care), + attr_data!(1, 29, descriptor::Attributes::ServerList, dont_care.clone()), + attr_data!(1, 29, descriptor::Attributes::PartsList, dont_care.clone()), + attr_data!(1, 29, descriptor::Attributes::ClientList, dont_care.clone()), + attr_data!(1, 6, GlobalElements::FeatureMap, dont_care.clone()), + attr_data!(1, 6, GlobalElements::AttributeList, dont_care.clone()), + attr_data!( + 1, + 6, + onoff::AttributesDiscriminants::OnOff, + dont_care.clone() + ), + attr_data!(1, echo::ID, GlobalElements::FeatureMap, dont_care.clone()), + attr_data!( + 1, + echo::ID, + GlobalElements::AttributeList, + dont_care.clone() + ), + attr_data!( + 1, + echo::ID, + echo::AttributesDiscriminants::Att1, + dont_care.clone() + ), + attr_data!( + 1, + echo::ID, + echo::AttributesDiscriminants::Att2, + dont_care.clone() + ), attr_data!( 1, echo::ID, diff --git a/matter_macro_derive/src/lib.rs b/matter_macro_derive/src/lib.rs index c63eddc4..619bf3b4 100644 --- a/matter_macro_derive/src/lib.rs +++ b/matter_macro_derive/src/lib.rs @@ -323,7 +323,7 @@ fn gen_fromtlv_for_struct( let mut t_iter = t.#datatype ()?.enter().ok_or_else(|| #krate::error::Error::new(#krate::error::ErrorCode::Invalid))?; let mut item = t_iter.next(); #( - let #idents = if Some(true) == item.map(|x| x.check_ctx_tag(#tags)) { + let #idents = if Some(true) == item.as_ref().map(|x| x.check_ctx_tag(#tags)) { let backup = item; item = t_iter.next(); #types::from_tlv(&backup.unwrap()) From 931e30601ee3474d781524bd0c58f02b7bf56a5d Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sun, 28 May 2023 11:13:02 +0000 Subject: [PATCH 46/72] Clippy --- matter/src/mdns.rs | 33 +++++++++++++++++-------------- matter/tests/interaction_model.rs | 2 +- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index 80333b14..7564578a 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -244,6 +244,7 @@ pub mod builtin { const IP_BIND_ADDR: (IpAddr, u16) = (IpAddr::V6(Ipv6Addr::UNSPECIFIED), 5353); + #[allow(clippy::too_many_arguments)] pub fn create_record( id: u16, hostname: &str, @@ -552,6 +553,7 @@ pub mod builtin { select(&mut broadcast, &mut respond).await.unwrap() } + #[allow(clippy::await_holding_refcell_ref)] async fn broadcast(&self) -> Result<(), Error> { loop { select( @@ -562,29 +564,30 @@ pub mod builtin { let mut index = 0; - while let Some(entry) = self - .0 - .entries - .borrow() - .get(index) - .map(|entry| entry.clone()) - { - info!("Broadasting mDNS entry {}", &entry.key); + loop { + let entry = self.0.entries.borrow().get(index).cloned(); - self.0.bind().await?; + if let Some(entry) = entry { + info!("Broadasting mDNS entry {}", &entry.key); - let udp = self.0.udp.borrow(); - let udp = udp.as_ref().unwrap(); + self.0.bind().await?; - for (addr, port) in IP_BROADCAST_ADDRS { - udp.send(SocketAddr::new(addr, port), &entry.record).await?; - } + let udp = self.0.udp.borrow(); + let udp = udp.as_ref().unwrap(); + + for (addr, port) in IP_BROADCAST_ADDRS { + udp.send(SocketAddr::new(addr, port), &entry.record).await?; + } - index += 1; + index += 1; + } else { + break; + } } } } + #[allow(clippy::await_holding_refcell_ref)] async fn respond(&self) -> Result<(), Error> { loop { let mut buf = [0; 1580]; diff --git a/matter/tests/interaction_model.rs b/matter/tests/interaction_model.rs index 5d2c21a6..9642ab23 100644 --- a/matter/tests/interaction_model.rs +++ b/matter/tests/interaction_model.rs @@ -68,7 +68,7 @@ impl DataHandler for DataModel { continue; }; let cmd_path_ib = i.path; - let mut common_data = &mut self.node; + let common_data = &mut self.node; common_data.endpoint = cmd_path_ib.path.endpoint.unwrap_or(1); common_data.cluster = cmd_path_ib.path.cluster.unwrap_or(0); common_data.command = cmd_path_ib.path.leaf.unwrap_or(0) as u16; From 443324a76427a693964c70442e88c1fee21e9a98 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sun, 28 May 2023 11:45:27 +0000 Subject: [PATCH 47/72] More inlines --- matter/src/secure_channel/case.rs | 1 + matter/src/secure_channel/core.rs | 1 + matter/src/secure_channel/pake.rs | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index c029963a..63d5e56e 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -74,6 +74,7 @@ pub struct Case<'a> { } impl<'a> Case<'a> { + #[inline(always)] pub fn new(fabric_mgr: &'a RefCell, rand: Rand) -> Self { Self { fabric_mgr, rand } } diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index c2fe059f..7287ae00 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -41,6 +41,7 @@ pub struct SecureChannel<'a> { } impl<'a> SecureChannel<'a> { + #[inline(always)] pub fn new( pase: &'a RefCell, fabric_mgr: &'a RefCell, diff --git a/matter/src/secure_channel/pake.rs b/matter/src/secure_channel/pake.rs index cd0ffaf7..ab095122 100644 --- a/matter/src/secure_channel/pake.rs +++ b/matter/src/secure_channel/pake.rs @@ -43,7 +43,6 @@ enum PaseMgrState { Disabled, } -// Could this lock be avoided? pub struct PaseMgr { state: PaseMgrState, epoch: Epoch, @@ -51,6 +50,7 @@ pub struct PaseMgr { } impl PaseMgr { + #[inline(always)] pub fn new(epoch: Epoch, rand: Rand) -> Self { Self { state: PaseMgrState::Disabled, From 2cde37899d333bec0421aac0e2d40f41877d5e4a Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sun, 28 May 2023 14:05:43 +0000 Subject: [PATCH 48/72] Make the example working again --- Cargo.toml | 7 + examples/onoff_light/src/main.rs | 267 ++++++++++++++++++++++++--- matter/Cargo.toml | 4 + matter/src/interaction_model/core.rs | 4 +- 4 files changed, 255 insertions(+), 27 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2e964561..7561a523 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,3 +8,10 @@ exclude = ["examples/*"] smol = { git = "https://github.com/esp-rs-compat/smol" } polling = { git = "https://github.com/esp-rs-compat/polling" } socket2 = { git = "https://github.com/esp-rs-compat/socket2" } + +[profile.release] +opt-level = 3 + +[profile.dev] +debug = true +opt-level = 3 diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 9cae24b5..20b1cc9b 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -15,9 +15,10 @@ * limitations under the License. */ -use std::borrow::Borrow; -use std::error::Error; +use core::borrow::Borrow; +use core::pin::pin; +use embassy_futures::select::select; use log::info; use matter::core::{CommissioningData, Matter}; use matter::data_model::cluster_basic_information::BasicInfoConfig; @@ -28,37 +29,74 @@ use matter::data_model::objects::*; use matter::data_model::root_endpoint; use matter::data_model::sdm::dev_att::DevAttDataFetcher; use matter::data_model::system_model::descriptor; +use matter::error::Error; use matter::interaction_model::core::InteractionModel; +use matter::mdns::builtin::Mdns; use matter::persist; use matter::secure_channel::spake2p::VerifierData; +use matter::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use matter::transport::{ mgr::RecvAction, mgr::TransportMgr, packet::MAX_RX_BUF_SIZE, packet::MAX_TX_BUF_SIZE, udp::UdpListener, }; +use matter::utils::select::EitherUnwrap; mod dev_att; -fn main() -> Result<(), impl Error> { - env_logger::init_from_env( - env_logger::Env::default().filter_or(env_logger::DEFAULT_FILTER_ENV, "info"), +#[cfg(feature = "std")] +fn main() -> Result<(), Error> { + let thread = std::thread::Builder::new() + .stack_size(120 * 1024) + .spawn(move || run()) + .unwrap(); + + thread.join().unwrap() + // run() +} + +#[cfg(not(feature = "std"))] +#[no_mangle] +fn main() { + run().unwrap(); +} + +fn run() -> Result<(), Error> { + initialize_logger(); + + info!( + "Matter memory: mDNS={}, Transport={} (of which Matter={})", + core::mem::size_of::(), + core::mem::size_of::(), + core::mem::size_of::(), ); - // vid/pid should match those in the DAC - let dev_info = BasicInfoConfig { - vid: 0xFFF1, - pid: 0x8000, - hw_ver: 2, - sw_ver: 1, - sw_ver_str: "1", - serial_no: "aabbccdd", - device_name: "OnOff Light", - }; + let (ipv4_addr, ipv6_addr) = initialize_network()?; - //let mut mdns = matter::mdns::astro::AstroMdns::new()?; - let mut mdns = matter::mdns::libmdns::LibMdns::new()?; - //let mut mdns = matter::mdns::DummyMdns {}; + let mut mdns = matter::mdns::builtin::Mdns::new( + 0, + "matter-demo", + ipv4_addr.octets(), + Some(ipv6_addr.octets()), + ); - let matter = Matter::new_default(&dev_info, &mut mdns, matter::MATTER_PORT); + let (mut mdns, mut mdns_runner) = mdns.split(); + //let (mut mdns, mdns_runner) = (matter::mdns::astro::AstroMdns::new()?, core::future::pending::pending()); + //let (mut mdns, mdns_runner) = (matter::mdns::DummyMdns {}, core::future::pending::pending()); + + let matter = Matter::new_default( + // vid/pid should match those in the DAC + &BasicInfoConfig { + vid: 0xFFF1, + pid: 0x8000, + hw_ver: 2, + sw_ver: 1, + sw_ver_str: "1", + serial_no: "aabbccdd", + device_name: "OnOff Light", + }, + &mut mdns, + matter::MATTER_PORT, + ); let dev_att = dev_att::HardCodedDevAtt::new(); @@ -88,11 +126,17 @@ fn main() -> Result<(), impl Error> { let matter = &matter; let dev_att = &dev_att; + let mdns_runner = &mut mdns_runner; let mut transport = TransportMgr::new(matter); + let transport = &mut transport; - smol::block_on(async move { - let udp = UdpListener::new().await?; + let mut io_fut = pin!(async move { + let udp = UdpListener::new(SocketAddr::new( + IpAddr::V6(Ipv6Addr::UNSPECIFIED), + matter::MATTER_PORT, + )) + .await?; loop { let mut rx_buf = [0; MAX_RX_BUF_SIZE]; @@ -100,12 +144,13 @@ fn main() -> Result<(), impl Error> { let (len, addr) = udp.recv(&mut rx_buf).await?; - let mut completion = transport.recv(addr, &mut rx_buf[..len], &mut tx_buf); + let mut completion = + transport.recv(Address::Udp(addr), &mut rx_buf[..len], &mut tx_buf); while let Some(action) = completion.next_action()? { match action { RecvAction::Send(addr, buf) => { - udp.send(addr, buf).await?; + udp.send(addr.unwrap_udp(), buf).await?; } RecvAction::Interact(mut ctx) => { let node = Node { @@ -127,7 +172,8 @@ fn main() -> Result<(), impl Error> { if im.handle(&mut ctx)? { if ctx.send()? { - udp.send(ctx.tx.peer, ctx.tx.as_slice()).await?; + udp.send(ctx.tx.peer.unwrap_udp(), ctx.tx.as_slice()) + .await?; } } } @@ -145,7 +191,13 @@ fn main() -> Result<(), impl Error> { #[allow(unreachable_code)] Ok::<_, matter::error::Error>(()) - })?; + }); + + let mut mdns_fut = pin!(async move { mdns_runner.run().await }); + + let mut fut = pin!(async move { select(&mut io_fut, &mut mdns_fut,).await.unwrap() }); + + smol::block_on(&mut fut)?; Ok::<_, matter::error::Error>(()) } @@ -163,3 +215,168 @@ fn handler<'a>(matter: &'a Matter<'a>, dev_att: &'a dyn DevAttDataFetcher) -> im cluster_on_off::OnOffCluster::new(*matter.borrow()), ) } + +#[cfg(not(target_os = "espidf"))] +#[inline(never)] +fn initialize_logger() { + env_logger::init_from_env( + env_logger::Env::default().filter_or(env_logger::DEFAULT_FILTER_ENV, "info"), + ); +} + +#[cfg(not(target_os = "espidf"))] +#[inline(never)] +fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr), Error> { + use log::error; + use matter::error::ErrorCode; + use nix::{net::if_::InterfaceFlags, sys::socket::SockaddrIn6}; + + let interfaces = || { + nix::ifaddrs::getifaddrs().unwrap().filter(|ia| { + ia.flags + .contains(InterfaceFlags::IFF_UP | InterfaceFlags::IFF_BROADCAST) + && !ia + .flags + .intersects(InterfaceFlags::IFF_LOOPBACK | InterfaceFlags::IFF_POINTOPOINT) + }) + }; + + // A quick and dirty way to get a network interface that has a link-local IPv6 address assigned as well as a non-loopback IPv4 + // Most likely, this is the interface we need + // (as opposed to all the docker and libvirt interfaces that might be assigned on the machine and which seem by default to be IPv4 only) + let (iname, ip, ipv6) = interfaces() + .filter_map(|ia| { + ia.address + .and_then(|addr| addr.as_sockaddr_in6().map(SockaddrIn6::ip)) + .filter(|ip| ip.octets()[..2] == [0xfe, 0x80]) + .map(|ipv6| (ia.interface_name, ipv6)) + }) + .filter_map(|(iname, ipv6)| { + interfaces() + .filter(|ia2| ia2.interface_name == iname) + .find_map(|ia2| { + ia2.address + .and_then(|addr| addr.as_sockaddr_in().map(|addr| addr.ip().into())) + .map(|ip| (iname.clone(), ip, ipv6)) + }) + }) + .next() + .ok_or_else(|| { + error!("Cannot find network interface suitable for mDNS broadcasting"); + ErrorCode::Network + })?; + + info!( + "Will use network interface {} with {}/{} for mDNS", + iname, ip, ipv6 + ); + + Ok((ip, ipv6)) +} + +#[cfg(target_os = "espidf")] +#[inline(never)] +fn initialize_logger() { + esp_idf_svc::log::EspLogger::initialize_default(); +} + +#[cfg(target_os = "espidf")] +#[inline(never)] +fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr), Error> { + use core::time::Duration; + + use embedded_svc::wifi::{AuthMethod, ClientConfiguration, Configuration}; + use esp_idf_hal::prelude::Peripherals; + use esp_idf_svc::handle::RawHandle; + use esp_idf_svc::wifi::{BlockingWifi, EspWifi}; + use esp_idf_svc::{eventloop::EspSystemEventLoop, nvs::EspDefaultNvsPartition}; + use esp_idf_sys::{ + self as _, esp, esp_ip6_addr_t, esp_netif_create_ip6_linklocal, esp_netif_get_ip6_linklocal, + }; // If using the `binstart` feature of `esp-idf-sys`, always keep this module imported + + const SSID: &'static str = env!("WIFI_SSID"); + const PASSWORD: &'static str = env!("WIFI_PASS"); + + #[allow(clippy::needless_update)] + { + // VFS is necessary for poll-based async IO + esp_idf_sys::esp!(unsafe { + esp_idf_sys::esp_vfs_eventfd_register(&esp_idf_sys::esp_vfs_eventfd_config_t { + max_fds: 5, + ..Default::default() + }) + })?; + } + + let peripherals = Peripherals::take().unwrap(); + let sys_loop = EspSystemEventLoop::take()?; + let nvs = EspDefaultNvsPartition::take()?; + + let mut wifi = EspWifi::new(peripherals.modem, sys_loop.clone(), Some(nvs))?; + + let mut bwifi = BlockingWifi::wrap(&mut wifi, sys_loop)?; + + let wifi_configuration: Configuration = Configuration::Client(ClientConfiguration { + ssid: SSID.into(), + bssid: None, + auth_method: AuthMethod::WPA2Personal, + password: PASSWORD.into(), + channel: None, + }); + + bwifi.set_configuration(&wifi_configuration)?; + + bwifi.start()?; + info!("Wifi started"); + + bwifi.connect()?; + info!("Wifi connected"); + + esp!(unsafe { + esp_netif_create_ip6_linklocal(bwifi.wifi_mut().sta_netif_mut().handle() as _) + })?; + + bwifi.wait_netif_up()?; + info!("Wifi netif up"); + + let ip_info = wifi.sta_netif().get_ip_info()?; + + let mut ipv6: esp_ip6_addr_t = Default::default(); + + info!("Waiting for IPv6 address"); + + while esp!(unsafe { esp_netif_get_ip6_linklocal(wifi.sta_netif().handle() as _, &mut ipv6) }) + .is_err() + { + info!("Waiting..."); + std::thread::sleep(Duration::from_secs(2)); + } + + info!("Wifi DHCP info: {:?}, IPv6: {:?}", ip_info, ipv6.addr); + + let ipv4_octets = ip_info.ip.octets(); + let ipv6_octets = [ + ipv6.addr[0].to_le_bytes()[0], + ipv6.addr[0].to_le_bytes()[1], + ipv6.addr[0].to_le_bytes()[2], + ipv6.addr[0].to_le_bytes()[3], + ipv6.addr[1].to_le_bytes()[0], + ipv6.addr[1].to_le_bytes()[1], + ipv6.addr[1].to_le_bytes()[2], + ipv6.addr[1].to_le_bytes()[3], + ipv6.addr[2].to_le_bytes()[0], + ipv6.addr[2].to_le_bytes()[1], + ipv6.addr[2].to_le_bytes()[2], + ipv6.addr[2].to_le_bytes()[3], + ipv6.addr[3].to_le_bytes()[0], + ipv6.addr[3].to_le_bytes()[1], + ipv6.addr[3].to_le_bytes()[2], + ipv6.addr[3].to_le_bytes()[3], + ]; + + // Not OK of course, but for a demo this is good enough + // Wifi will continue to be available and working in the background + core::mem::forget(wifi); + + Ok((ipv4_octets.into(), ipv6_octets.into())) +} diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 22ef439e..d1bcced1 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -74,10 +74,14 @@ x509-cert = { version = "0.2.0", default-features = false, features = ["pem"], o [target.'cfg(not(target_os = "espidf"))'.dependencies] mbedtls = { git = "https://github.com/fortanix/rust-mbedtls", optional = true } env_logger = { version = "0.10.0", optional = true } +nix = { version = "0.26", features = ["net"] } [target.'cfg(target_os = "espidf")'.dependencies] esp-idf-sys = { version = "0.33", default-features = false, features = ["native"] } +[[example]] +name = "onoff_light" +path = "../examples/onoff_light/src/main.rs" [[example]] name = "speaker" diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index e24ec079..0686061c 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -206,8 +206,8 @@ impl<'a, 'b> Transaction<'a, 'b> { /* Interaction Model ID as per the Matter Spec */ const PROTO_ID_INTERACTION_MODEL: usize = 0x01; -const MAX_RESUME_PATHS: usize = 128; -const MAX_RESUME_DATAVER_FILTERS: usize = 128; +const MAX_RESUME_PATHS: usize = 32; +const MAX_RESUME_DATAVER_FILTERS: usize = 32; // This is the amount of space we reserve for other things to be attached towards // the end of long reads. From b94484b67e90936d0e5704a0a81e0944bb7ec7b9 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Wed, 31 May 2023 12:51:37 +0000 Subject: [PATCH 49/72] Make sure nix is not brought in no-std compiles --- matter/Cargo.toml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/matter/Cargo.toml b/matter/Cargo.toml index d1bcced1..410b30ce 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -16,8 +16,9 @@ path = "src/lib.rs" [features] default = ["os", "crypto_rustcrypto"] -os = ["std", "backtrace", "critical-section/std", "embassy-sync/std", "embassy-time/std"] -std = ["alloc", "env_logger", "rand", "qrcode", "async-io", "smol", "esp-idf-sys/std"] +#default = ["crypto_rustcrypto"] +os = ["std", "backtrace", "env_logger", "nix", "critical-section/std", "embassy-sync/std", "embassy-time/std"] +std = ["alloc", "rand", "qrcode", "async-io", "smol", "esp-idf-sys/std"] backtrace = [] alloc = [] nightly = [] @@ -74,7 +75,7 @@ x509-cert = { version = "0.2.0", default-features = false, features = ["pem"], o [target.'cfg(not(target_os = "espidf"))'.dependencies] mbedtls = { git = "https://github.com/fortanix/rust-mbedtls", optional = true } env_logger = { version = "0.10.0", optional = true } -nix = { version = "0.26", features = ["net"] } +nix = { version = "0.26", features = ["net"], optional = true } [target.'cfg(target_os = "espidf")'.dependencies] esp-idf-sys = { version = "0.33", default-features = false, features = ["native"] } From 8e9d8887dad12f3e16d0626cbc846b406d4f75d9 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Thu, 1 Jun 2023 04:59:01 +0000 Subject: [PATCH 50/72] Fix a bug in mDNS --- matter/src/mdns.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index 7564578a..eaba7ee9 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -392,9 +392,9 @@ pub mod builtin { impl MdnsEntry { #[inline(always)] - const fn new() -> Self { + const fn new(key: heapless::String<64>) -> Self { Self { - key: heapless::String::new(), + key, record: heapless::Vec::new(), } } @@ -479,7 +479,7 @@ pub mod builtin { entries.retain(|entry| entry.key != key); entries - .push(MdnsEntry::new()) + .push(MdnsEntry::new(key)) .map_err(|_| ErrorCode::NoSpace)?; let entry = entries.iter_mut().last().unwrap(); From 1b879f1a5b4d4ed2c92650c0bac06e8bd7be147c Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Fri, 9 Jun 2023 07:47:49 +0000 Subject: [PATCH 51/72] Simplify main user-facing API --- examples/onoff_light/src/main.rs | 29 ++++--- matter/src/core.rs | 53 +++++++++--- matter/src/data_model/root_endpoint.rs | 8 +- matter/src/secure_channel/core.rs | 30 +++++-- matter/src/secure_channel/pake.rs | 4 + matter/src/transport/{mgr.rs => core.rs} | 100 +++++++++++------------ matter/src/transport/exchange.rs | 1 + matter/src/transport/mod.rs | 2 +- matter/tests/common/im_engine.rs | 6 +- 9 files changed, 137 insertions(+), 96 deletions(-) rename matter/src/transport/{mgr.rs => core.rs} (76%) diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 20b1cc9b..baf20223 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -27,7 +27,6 @@ use matter::data_model::core::DataModel; use matter::data_model::device_types::DEV_TYPE_ON_OFF_LIGHT; use matter::data_model::objects::*; use matter::data_model::root_endpoint; -use matter::data_model::sdm::dev_att::DevAttDataFetcher; use matter::data_model::system_model::descriptor; use matter::error::Error; use matter::interaction_model::core::InteractionModel; @@ -36,7 +35,7 @@ use matter::persist; use matter::secure_channel::spake2p::VerifierData; use matter::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use matter::transport::{ - mgr::RecvAction, mgr::TransportMgr, packet::MAX_RX_BUF_SIZE, packet::MAX_TX_BUF_SIZE, + core::RecvAction, core::Transport, packet::MAX_RX_BUF_SIZE, packet::MAX_TX_BUF_SIZE, udp::UdpListener, }; use matter::utils::select::EitherUnwrap; @@ -66,7 +65,7 @@ fn run() -> Result<(), Error> { info!( "Matter memory: mDNS={}, Transport={} (of which Matter={})", core::mem::size_of::(), - core::mem::size_of::(), + core::mem::size_of::(), core::mem::size_of::(), ); @@ -83,9 +82,11 @@ fn run() -> Result<(), Error> { //let (mut mdns, mdns_runner) = (matter::mdns::astro::AstroMdns::new()?, core::future::pending::pending()); //let (mut mdns, mdns_runner) = (matter::mdns::DummyMdns {}, core::future::pending::pending()); + let dev_att = dev_att::HardCodedDevAtt::new(); + let matter = Matter::new_default( // vid/pid should match those in the DAC - &BasicInfoConfig { + BasicInfoConfig { vid: 0xFFF1, pid: 0x8000, hw_ver: 2, @@ -94,12 +95,11 @@ fn run() -> Result<(), Error> { serial_no: "aabbccdd", device_name: "OnOff Light", }, + &dev_att, &mut mdns, matter::MATTER_PORT, ); - let dev_att = dev_att::HardCodedDevAtt::new(); - let psm_path = std::env::temp_dir().join("matter-iot"); info!("Persisting from/to {}", psm_path.display()); @@ -115,7 +115,9 @@ fn run() -> Result<(), Error> { matter.load_fabrics(data)?; } - matter.start( + let mut transport = Transport::new(&matter); + + transport.start( CommissioningData { // TODO: Hard-coded for now verifier: VerifierData::new_with_pw(123456, *matter.borrow()), @@ -125,10 +127,7 @@ fn run() -> Result<(), Error> { )?; let matter = &matter; - let dev_att = &dev_att; let mdns_runner = &mut mdns_runner; - - let mut transport = TransportMgr::new(matter); let transport = &mut transport; let mut io_fut = pin!(async move { @@ -165,7 +164,7 @@ fn run() -> Result<(), Error> { ], }; - let mut handler = handler(matter, dev_att); + let mut handler = handler(matter); let mut im = InteractionModel(DataModel::new(matter.borrow(), &node, &mut handler)); @@ -180,11 +179,11 @@ fn run() -> Result<(), Error> { } } - if let Some(data) = matter.store_fabrics(&mut buf)? { + if let Some(data) = transport.matter().store_fabrics(&mut buf)? { psm.store("fabrics", data)?; } - if let Some(data) = matter.store_acls(&mut buf)? { + if let Some(data) = transport.matter().store_acls(&mut buf)? { psm.store("acls", data)?; } } @@ -202,8 +201,8 @@ fn run() -> Result<(), Error> { Ok::<_, matter::error::Error>(()) } -fn handler<'a>(matter: &'a Matter<'a>, dev_att: &'a dyn DevAttDataFetcher) -> impl Handler + 'a { - root_endpoint::handler(0, dev_att, matter) +fn handler<'a>(matter: &'a Matter<'a>) -> impl Handler + 'a { + root_endpoint::handler(0, matter) .chain( 1, descriptor::ID, diff --git a/matter/src/core.rs b/matter/src/core.rs index fa960f78..dacddbde 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -19,7 +19,10 @@ use core::{borrow::Borrow, cell::RefCell}; use crate::{ acl::AclMgr, - data_model::{cluster_basic_information::BasicInfoConfig, sdm::failsafe::FailSafe}, + data_model::{ + cluster_basic_information::BasicInfoConfig, + sdm::{dev_att::DevAttDataFetcher, failsafe::FailSafe}, + }, error::*, fabric::FabricMgr, mdns::{Mdns, MdnsMgr}, @@ -48,17 +51,24 @@ pub struct Matter<'a> { pub mdns_mgr: RefCell>, pub epoch: Epoch, pub rand: Rand, - pub dev_det: &'a BasicInfoConfig<'a>, + pub dev_det: BasicInfoConfig<'a>, + pub dev_att: &'a dyn DevAttDataFetcher, + pub port: u16, } impl<'a> Matter<'a> { #[cfg(feature = "std")] #[inline(always)] - pub fn new_default(dev_det: &'a BasicInfoConfig, mdns: &'a mut dyn Mdns, port: u16) -> Self { + pub fn new_default( + dev_det: BasicInfoConfig<'a>, + dev_att: &'a dyn DevAttDataFetcher, + mdns: &'a mut dyn Mdns, + port: u16, + ) -> Self { use crate::utils::epoch::sys_epoch; use crate::utils::rand::sys_rand; - Self::new(dev_det, mdns, sys_epoch, sys_rand, port) + Self::new(dev_det, dev_att, mdns, sys_epoch, sys_rand, port) } /// Creates a new Matter object @@ -69,7 +79,8 @@ impl<'a> Matter<'a> { /// this object to return the device attestation details when queried upon. #[inline(always)] pub fn new( - dev_det: &'a BasicInfoConfig, + dev_det: BasicInfoConfig<'a>, + dev_att: &'a dyn DevAttDataFetcher, mdns: &'a mut dyn Mdns, epoch: Epoch, rand: Rand, @@ -90,11 +101,21 @@ impl<'a> Matter<'a> { epoch, rand, dev_det, + dev_att, + port, } } - pub fn dev_det(&self) -> &BasicInfoConfig { - self.dev_det + pub fn dev_det(&self) -> &BasicInfoConfig<'_> { + &self.dev_det + } + + pub fn dev_att(&self) -> &dyn DevAttDataFetcher { + self.dev_att + } + + pub fn port(&self) -> u16 { + self.port } pub fn load_fabrics(&self, data: &[u8]) -> Result<(), Error> { @@ -119,11 +140,15 @@ impl<'a> Matter<'a> { self.acl_mgr.borrow().is_changed() || self.fabric_mgr.borrow().is_changed() } - pub fn start(&self, dev_comm: CommissioningData, buf: &mut [u8]) -> Result<(), Error> { - let open_comm_window = self.fabric_mgr.borrow().is_empty(); - if open_comm_window { + pub fn start_comissioning( + &self, + dev_comm: CommissioningData, + buf: &mut [u8], + ) -> Result { + if !self.pase_mgr.borrow().is_pase_session_enabled() && self.fabric_mgr.borrow().is_empty() + { print_pairing_code_and_qr( - self.dev_det, + &self.dev_det, &dev_comm, DiscoveryCapabilities::default(), buf, @@ -134,9 +159,11 @@ impl<'a> Matter<'a> { dev_comm.discriminator, &mut self.mdns_mgr.borrow_mut(), )?; - } - Ok(()) + Ok(true) + } else { + Ok(false) + } } } diff --git a/matter/src/data_model/root_endpoint.rs b/matter/src/data_model/root_endpoint.rs index 7ad87fb6..859d2bce 100644 --- a/matter/src/data_model/root_endpoint.rs +++ b/matter/src/data_model/root_endpoint.rs @@ -55,15 +55,11 @@ pub fn endpoint(id: EndptId) -> Endpoint<'static> { } } -pub fn handler<'a>( - endpoint_id: u16, - dev_att: &'a dyn DevAttDataFetcher, - matter: &'a Matter<'a>, -) -> RootEndpointHandler<'a> { +pub fn handler<'a>(endpoint_id: u16, matter: &'a Matter<'a>) -> RootEndpointHandler<'a> { wrap( endpoint_id, matter.dev_det(), - dev_att, + matter.dev_att(), matter.borrow(), matter.borrow(), matter.borrow(), diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index 7287ae00..21196917 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use core::cell::RefCell; +use core::{borrow::Borrow, cell::RefCell}; use crate::{ error::*, @@ -24,7 +24,7 @@ use crate::{ secure_channel::common::*, tlv, transport::{proto_ctx::ProtoCtx, session::CloneData}, - utils::rand::Rand, + utils::{epoch::Epoch, rand::Rand}, }; use log::{error, info}; use num; @@ -42,14 +42,32 @@ pub struct SecureChannel<'a> { impl<'a> SecureChannel<'a> { #[inline(always)] - pub fn new( + pub fn new< + T: Borrow> + + Borrow> + + Borrow>> + + Borrow + + Borrow, + >( + matter: &'a T, + ) -> Self { + Self::wrap( + matter.borrow(), + matter.borrow(), + matter.borrow(), + *matter.borrow(), + ) + } + + #[inline(always)] + pub fn wrap( pase: &'a RefCell, - fabric_mgr: &'a RefCell, + fabric: &'a RefCell, mdns: &'a RefCell>, rand: Rand, ) -> Self { - SecureChannel { - case: Case::new(fabric_mgr, rand), + Self { + case: Case::new(fabric, rand), pase, mdns, } diff --git a/matter/src/secure_channel/pake.rs b/matter/src/secure_channel/pake.rs index ab095122..b5a29a2d 100644 --- a/matter/src/secure_channel/pake.rs +++ b/matter/src/secure_channel/pake.rs @@ -59,6 +59,10 @@ impl PaseMgr { } } + pub fn is_pase_session_enabled(&self) -> bool { + matches!(&self.state, PaseMgrState::Enabled(_, _, _)) + } + pub fn enable_pase_session( &mut self, verifier: VerifierData, diff --git a/matter/src/transport/mgr.rs b/matter/src/transport/core.rs similarity index 76% rename from matter/src/transport/mgr.rs rename to matter/src/transport/core.rs index eeff6ff9..1d02bc0c 100644 --- a/matter/src/transport/mgr.rs +++ b/matter/src/transport/core.rs @@ -15,22 +15,14 @@ * limitations under the License. */ -use core::borrow::Borrow; -use core::cell::RefCell; - use log::info; -use crate::error::*; -use crate::fabric::FabricMgr; -use crate::mdns::MdnsMgr; -use crate::secure_channel::pake::PaseMgr; +use crate::{error::*, CommissioningData, Matter}; use crate::secure_channel::common::PROTO_ID_SECURE_CHANNEL; use crate::secure_channel::core::SecureChannel; use crate::transport::mrp::ReliableMessage; use crate::transport::{exchange, network::Address, packet::Packet}; -use crate::utils::epoch::Epoch; -use crate::utils::rand::Rand; use super::proto_ctx::ProtoCtx; use super::session::CloneData; @@ -50,7 +42,7 @@ pub enum RecvAction<'r, 'p> { } pub struct RecvCompletion<'r, 'a, 'p> { - mgr: &'r mut TransportMgr<'a>, + transport: &'r mut Transport<'a>, rx: Packet<'p>, tx: Packet<'p>, state: RecvState, @@ -69,20 +61,25 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { } fn maybe_next_action(&mut self) -> Result>>, Error> { - self.mgr.exch_mgr.purge(); + self.transport.exch_mgr.purge(); self.tx.reset(); let (state, next) = match core::mem::replace(&mut self.state, RecvState::New) { RecvState::New => { - self.mgr.exch_mgr.get_sess_mgr().decode(&mut self.rx)?; + self.transport + .exch_mgr + .get_sess_mgr() + .decode(&mut self.rx)?; (RecvState::OpenExchange, None) } - RecvState::OpenExchange => match self.mgr.exch_mgr.recv(&mut self.rx) { + RecvState::OpenExchange => match self.transport.exch_mgr.recv(&mut self.rx) { Ok(Some(exch_ctx)) => { if self.rx.get_proto_id() == PROTO_ID_SECURE_CHANNEL { let mut proto_ctx = ProtoCtx::new(exch_ctx, &self.rx, &mut self.tx); - let (reply, clone_data) = self.mgr.secure_channel.handle(&mut proto_ctx)?; + let mut secure_channel = SecureChannel::new(self.transport.matter); + + let (reply, clone_data) = secure_channel.handle(&mut proto_ctx)?; let state = if let Some(clone_data) = clone_data { RecvState::AddSession(clone_data) @@ -115,15 +112,17 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { _ => Err(e)?, }, }, - RecvState::AddSession(clone_data) => match self.mgr.exch_mgr.add_session(&clone_data) { - Ok(_) => (RecvState::Ack, None), - Err(e) => match e.code() { - ErrorCode::NoSpace => (RecvState::EvictSession2(clone_data), None), - _ => Err(e)?, - }, - }, + RecvState::AddSession(clone_data) => { + match self.transport.exch_mgr.add_session(&clone_data) { + Ok(_) => (RecvState::Ack, None), + Err(e) => match e.code() { + ErrorCode::NoSpace => (RecvState::EvictSession2(clone_data), None), + _ => Err(e)?, + }, + } + } RecvState::EvictSession => { - if self.mgr.exch_mgr.evict_session(&mut self.tx)? { + if self.transport.exch_mgr.evict_session(&mut self.tx)? { ( RecvState::OpenExchange, Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))), @@ -133,7 +132,7 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { } } RecvState::EvictSession2(clone_data) => { - if self.mgr.exch_mgr.evict_session(&mut self.tx)? { + if self.transport.exch_mgr.evict_session(&mut self.tx)? { ( RecvState::AddSession(clone_data), Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))), @@ -143,12 +142,12 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { } } RecvState::Ack => { - if let Some(exch_id) = self.mgr.exch_mgr.pending_ack() { + if let Some(exch_id) = self.transport.exch_mgr.pending_ack() { info!("Sending MRP Standalone ACK for exch {}", exch_id); ReliableMessage::prepare_ack(exch_id, &mut self.tx); - if self.mgr.exch_mgr.send(exch_id, &mut self.tx)? { + if self.transport.exch_mgr.send(exch_id, &mut self.tx)? { ( RecvState::Ack, Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))), @@ -176,7 +175,7 @@ pub enum NotifyAction<'r, 'p> { pub struct NotifyCompletion<'r, 'a, 'p> { // TODO - _mgr: &'r mut TransportMgr<'a>, + _transport: &'r mut Transport<'a>, _rx: &'r mut Packet<'p>, _tx: &'r mut Packet<'p>, _state: NotifyState, @@ -199,40 +198,37 @@ impl<'r, 'a, 'p> NotifyCompletion<'r, 'a, 'p> { } } -pub struct TransportMgr<'a> { +pub struct Transport<'a> { + matter: &'a Matter<'a>, exch_mgr: exchange::ExchangeMgr, - secure_channel: SecureChannel<'a>, } -impl<'a> TransportMgr<'a> { - pub fn new< - T: Borrow> - + Borrow> - + Borrow>> - + Borrow - + Borrow, - >( - matter: &'a T, - ) -> Self { - Self::wrap( - SecureChannel::new( - matter.borrow(), - matter.borrow(), - matter.borrow(), - *matter.borrow(), - ), - *matter.borrow(), - *matter.borrow(), - ) - } +impl<'a> Transport<'a> { + #[inline(always)] + pub fn new(matter: &'a Matter<'a>) -> Self { + let epoch = matter.epoch; + let rand = matter.rand; - pub fn wrap(secure_channel: SecureChannel<'a>, epoch: Epoch, rand: Rand) -> Self { Self { + matter, exch_mgr: exchange::ExchangeMgr::new(epoch, rand), - secure_channel, } } + pub fn matter(&self) -> &Matter<'a> { + &self.matter + } + + pub fn start(&mut self, dev_comm: CommissioningData, buf: &mut [u8]) -> Result<(), Error> { + info!("Starting Matter transport"); + + if self.matter().start_comissioning(dev_comm, buf)? { + info!("Comissioning started"); + } + + Ok(()) + } + pub fn recv<'r, 'p>( &'r mut self, addr: Address, @@ -245,7 +241,7 @@ impl<'a> TransportMgr<'a> { rx.peer = addr; RecvCompletion { - mgr: self, + transport: self, rx, tx, state: RecvState::New, diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index 04b63db1..4910dbc0 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -279,6 +279,7 @@ pub struct ExchangeMgr { pub const MAX_MRP_ENTRIES: usize = 4; impl ExchangeMgr { + #[inline(always)] pub fn new(epoch: Epoch, rand: Rand) -> Self { Self { sess_mgr: SessionMgr::new(epoch, rand), diff --git a/matter/src/transport/mod.rs b/matter/src/transport/mod.rs index 1a81c75c..18957be5 100644 --- a/matter/src/transport/mod.rs +++ b/matter/src/transport/mod.rs @@ -15,9 +15,9 @@ * limitations under the License. */ +pub mod core; mod dedup; pub mod exchange; -pub mod mgr; pub mod mrp; pub mod network; pub mod packet; diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index ce608c50..70b2aca0 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -65,7 +65,7 @@ const BASIC_INFO: BasicInfoConfig<'static> = BasicInfoConfig { device_name: "Test Device", }; -pub struct DummyDevAtt {} +struct DummyDevAtt; impl DevAttDataFetcher for DummyDevAtt { fn get_devatt_data(&self, _data_type: DataType, _data: &mut [u8]) -> Result { @@ -110,7 +110,7 @@ pub fn matter(mdns: &mut dyn Mdns) -> Matter<'_> { #[cfg(not(feature = "std"))] use matter::utils::epoch::dummy_epoch as epoch; - Matter::new(&BASIC_INFO, mdns, epoch, dummy_rand, 5540) + Matter::new(BASIC_INFO, &DummyDevAtt, mdns, epoch, dummy_rand, 5540) } /// An Interaction Model Engine to facilitate easy testing @@ -161,7 +161,7 @@ impl<'a> ImEngine<'a> { }, ], }, - root_endpoint::handler(0, &DummyDevAtt {}, matter) + root_endpoint::handler(0, matter) .chain(0, echo_cluster::ID, EchoCluster::new(2, *matter.borrow())) .chain(1, descriptor::ID, DescriptorCluster::new(*matter.borrow())) .chain(1, echo_cluster::ID, EchoCluster::new(3, *matter.borrow())) From de3d3de004ac52af2426ce122e7b935382e1f19a Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sat, 10 Jun 2023 14:01:35 +0000 Subject: [PATCH 52/72] Make Matter covariant over its lifetime --- examples/onoff_light/src/main.rs | 88 +++---- matter/src/core.rs | 52 ++-- matter/src/data_model/root_endpoint.rs | 21 +- .../src/data_model/sdm/admin_commissioning.rs | 10 +- matter/src/data_model/sdm/noc.rs | 8 +- matter/src/fabric.rs | 6 +- matter/src/interaction_model/core.rs | 21 +- matter/src/mdns.rs | 226 +++++++++++------- matter/src/secure_channel/common.rs | 5 +- matter/src/secure_channel/core.rs | 16 +- matter/src/secure_channel/pake.rs | 6 +- matter/src/secure_channel/status_report.rs | 1 + matter/src/transport/core.rs | 39 ++- matter/src/transport/exchange.rs | 2 +- matter/src/transport/mod.rs | 1 + matter/src/transport/packet.rs | 107 ++++++++- matter/src/transport/pipe.rs | 94 ++++++++ matter/src/transport/session.rs | 75 +----- matter/src/utils/select.rs | 3 + matter/src/utils/writebuf.rs | 4 + matter/tests/common/echo_cluster.rs | 6 +- matter/tests/common/im_engine.rs | 4 +- 22 files changed, 512 insertions(+), 283 deletions(-) create mode 100644 matter/src/transport/pipe.rs diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index baf20223..b2e5091c 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -30,7 +30,7 @@ use matter::data_model::root_endpoint; use matter::data_model::system_model::descriptor; use matter::error::Error; use matter::interaction_model::core::InteractionModel; -use matter::mdns::builtin::Mdns; +use matter::mdns::builtin::{Mdns, MdnsRxBuf, MdnsTxBuf}; use matter::persist; use matter::secure_channel::spake2p::VerifierData; use matter::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; @@ -46,7 +46,7 @@ mod dev_att; fn main() -> Result<(), Error> { let thread = std::thread::Builder::new() .stack_size(120 * 1024) - .spawn(move || run()) + .spawn(run) .unwrap(); thread.join().unwrap() @@ -63,10 +63,10 @@ fn run() -> Result<(), Error> { initialize_logger(); info!( - "Matter memory: mDNS={}, Transport={} (of which Matter={})", + "Matter memory: mDNS={}, Matter={}, Transport={}", core::mem::size_of::(), - core::mem::size_of::(), core::mem::size_of::(), + core::mem::size_of::(), ); let (ipv4_addr, ipv6_addr) = initialize_network()?; @@ -78,7 +78,7 @@ fn run() -> Result<(), Error> { Some(ipv6_addr.octets()), ); - let (mut mdns, mut mdns_runner) = mdns.split(); + let (mdns, mut mdns_runner) = mdns.split(); //let (mut mdns, mdns_runner) = (matter::mdns::astro::AstroMdns::new()?, core::future::pending::pending()); //let (mut mdns, mdns_runner) = (matter::mdns::DummyMdns {}, core::future::pending::pending()); @@ -86,7 +86,7 @@ fn run() -> Result<(), Error> { let matter = Matter::new_default( // vid/pid should match those in the DAC - BasicInfoConfig { + &BasicInfoConfig { vid: 0xFFF1, pid: 0x8000, hw_ver: 2, @@ -96,7 +96,7 @@ fn run() -> Result<(), Error> { device_name: "OnOff Light", }, &dev_att, - &mut mdns, + &mdns, matter::MATTER_PORT, ); @@ -106,12 +106,13 @@ fn run() -> Result<(), Error> { let psm = persist::FilePsm::new(psm_path)?; let mut buf = [0; 4096]; + let buf = &mut buf; - if let Some(data) = psm.load("acls", &mut buf)? { + if let Some(data) = psm.load("acls", buf)? { matter.load_acls(data)?; } - if let Some(data) = psm.load("fabrics", &mut buf)? { + if let Some(data) = psm.load("fabrics", buf)? { matter.load_fabrics(data)?; } @@ -123,12 +124,33 @@ fn run() -> Result<(), Error> { verifier: VerifierData::new_with_pw(123456, *matter.borrow()), discriminator: 250, }, - &mut buf, + buf, )?; - let matter = &matter; + let node = Node { + id: 0, + endpoints: &[ + root_endpoint::endpoint(0), + Endpoint { + id: 1, + device_type: DEV_TYPE_ON_OFF_LIGHT, + clusters: &[descriptor::CLUSTER, cluster_on_off::CLUSTER], + }, + ], + }; + + let mut handler = handler(&matter); + + let mut im = InteractionModel(DataModel::new(matter.borrow(), &node, &mut handler)); + + let mut rx_buf = [0; MAX_RX_BUF_SIZE]; + let mut tx_buf = [0; MAX_TX_BUF_SIZE]; + + let im = &mut im; let mdns_runner = &mut mdns_runner; let transport = &mut transport; + let rx_buf = &mut rx_buf; + let tx_buf = &mut tx_buf; let mut io_fut = pin!(async move { let udp = UdpListener::new(SocketAddr::new( @@ -138,13 +160,9 @@ fn run() -> Result<(), Error> { .await?; loop { - let mut rx_buf = [0; MAX_RX_BUF_SIZE]; - let mut tx_buf = [0; MAX_TX_BUF_SIZE]; + let (len, addr) = udp.recv(rx_buf).await?; - let (len, addr) = udp.recv(&mut rx_buf).await?; - - let mut completion = - transport.recv(Address::Udp(addr), &mut rx_buf[..len], &mut tx_buf); + let mut completion = transport.recv(Address::Udp(addr), &mut rx_buf[..len], tx_buf); while let Some(action) = completion.next_action()? { match action { @@ -152,38 +170,19 @@ fn run() -> Result<(), Error> { udp.send(addr.unwrap_udp(), buf).await?; } RecvAction::Interact(mut ctx) => { - let node = Node { - id: 0, - endpoints: &[ - root_endpoint::endpoint(0), - Endpoint { - id: 1, - device_type: DEV_TYPE_ON_OFF_LIGHT, - clusters: &[descriptor::CLUSTER, cluster_on_off::CLUSTER], - }, - ], - }; - - let mut handler = handler(matter); - - let mut im = - InteractionModel(DataModel::new(matter.borrow(), &node, &mut handler)); - - if im.handle(&mut ctx)? { - if ctx.send()? { - udp.send(ctx.tx.peer.unwrap_udp(), ctx.tx.as_slice()) - .await?; - } + if im.handle(&mut ctx)? && ctx.send()? { + udp.send(ctx.tx.peer.unwrap_udp(), ctx.tx.as_slice()) + .await?; } } } } - if let Some(data) = transport.matter().store_fabrics(&mut buf)? { + if let Some(data) = transport.matter().store_fabrics(buf)? { psm.store("fabrics", data)?; } - if let Some(data) = transport.matter().store_acls(&mut buf)? { + if let Some(data) = transport.matter().store_acls(buf)? { psm.store("acls", data)?; } } @@ -192,7 +191,12 @@ fn run() -> Result<(), Error> { Ok::<_, matter::error::Error>(()) }); - let mut mdns_fut = pin!(async move { mdns_runner.run().await }); + let mut tx_buf = MdnsTxBuf::uninit(); + let mut rx_buf = MdnsRxBuf::uninit(); + let tx_buf = &mut tx_buf; + let rx_buf = &mut rx_buf; + + let mut mdns_fut = pin!(async move { mdns_runner.run_udp(tx_buf, rx_buf).await }); let mut fut = pin!(async move { select(&mut io_fut, &mut mdns_fut,).await.unwrap() }); diff --git a/matter/src/core.rs b/matter/src/core.rs index dacddbde..c6c5dc10 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -48,10 +48,10 @@ pub struct Matter<'a> { pub acl_mgr: RefCell, pub pase_mgr: RefCell, pub failsafe: RefCell, - pub mdns_mgr: RefCell>, + pub mdns_mgr: MdnsMgr<'a>, pub epoch: Epoch, pub rand: Rand, - pub dev_det: BasicInfoConfig<'a>, + pub dev_det: &'a BasicInfoConfig<'a>, pub dev_att: &'a dyn DevAttDataFetcher, pub port: u16, } @@ -60,9 +60,9 @@ impl<'a> Matter<'a> { #[cfg(feature = "std")] #[inline(always)] pub fn new_default( - dev_det: BasicInfoConfig<'a>, + dev_det: &'a BasicInfoConfig<'a>, dev_att: &'a dyn DevAttDataFetcher, - mdns: &'a mut dyn Mdns, + mdns: &'a dyn Mdns, port: u16, ) -> Self { use crate::utils::epoch::sys_epoch; @@ -79,9 +79,9 @@ impl<'a> Matter<'a> { /// this object to return the device attestation details when queried upon. #[inline(always)] pub fn new( - dev_det: BasicInfoConfig<'a>, + dev_det: &'a BasicInfoConfig<'a>, dev_att: &'a dyn DevAttDataFetcher, - mdns: &'a mut dyn Mdns, + mdns: &'a dyn Mdns, epoch: Epoch, rand: Rand, port: u16, @@ -91,13 +91,7 @@ impl<'a> Matter<'a> { acl_mgr: RefCell::new(AclMgr::new()), pase_mgr: RefCell::new(PaseMgr::new(epoch, rand)), failsafe: RefCell::new(FailSafe::new()), - mdns_mgr: RefCell::new(MdnsMgr::new( - dev_det.vid, - dev_det.pid, - dev_det.device_name, - port, - mdns, - )), + mdns_mgr: MdnsMgr::new(dev_det.vid, dev_det.pid, dev_det.device_name, port, mdns), epoch, rand, dev_det, @@ -107,7 +101,7 @@ impl<'a> Matter<'a> { } pub fn dev_det(&self) -> &BasicInfoConfig<'_> { - &self.dev_det + self.dev_det } pub fn dev_att(&self) -> &dyn DevAttDataFetcher { @@ -119,9 +113,7 @@ impl<'a> Matter<'a> { } pub fn load_fabrics(&self, data: &[u8]) -> Result<(), Error> { - self.fabric_mgr - .borrow_mut() - .load(data, &mut self.mdns_mgr.borrow_mut()) + self.fabric_mgr.borrow_mut().load(data, &self.mdns_mgr) } pub fn load_acls(&self, data: &[u8]) -> Result<(), Error> { @@ -148,7 +140,7 @@ impl<'a> Matter<'a> { if !self.pase_mgr.borrow().is_pase_session_enabled() && self.fabric_mgr.borrow().is_empty() { print_pairing_code_and_qr( - &self.dev_det, + self.dev_det, &dev_comm, DiscoveryCapabilities::default(), buf, @@ -157,7 +149,7 @@ impl<'a> Matter<'a> { self.pase_mgr.borrow_mut().enable_pase_session( dev_comm.verifier, dev_comm.discriminator, - &mut self.mdns_mgr.borrow_mut(), + &self.mdns_mgr, )?; Ok(true) @@ -191,12 +183,30 @@ impl<'a> Borrow> for Matter<'a> { } } -impl<'a> Borrow>> for Matter<'a> { - fn borrow(&self) -> &RefCell> { +impl<'a> Borrow> for Matter<'a> { + fn borrow(&self) -> &MdnsMgr<'a> { &self.mdns_mgr } } +impl<'a> Borrow> for Matter<'a> { + fn borrow(&self) -> &BasicInfoConfig<'a> { + self.dev_det + } +} + +impl<'a> Borrow for Matter<'a> { + fn borrow(&self) -> &(dyn DevAttDataFetcher + 'a) { + self.dev_att + } +} + +impl<'a> Borrow for Matter<'a> { + fn borrow(&self) -> &(dyn Mdns + 'a) { + self.mdns_mgr.mdns + } +} + impl<'a> Borrow for Matter<'a> { fn borrow(&self) -> &Epoch { &self.epoch diff --git a/matter/src/data_model/root_endpoint.rs b/matter/src/data_model/root_endpoint.rs index 859d2bce..1bc22fe2 100644 --- a/matter/src/data_model/root_endpoint.rs +++ b/matter/src/data_model/root_endpoint.rs @@ -7,7 +7,6 @@ use crate::{ mdns::MdnsMgr, secure_channel::pake::PaseMgr, utils::{epoch::Epoch, rand::Rand}, - Matter, }; use super::{ @@ -55,11 +54,23 @@ pub fn endpoint(id: EndptId) -> Endpoint<'static> { } } -pub fn handler<'a>(endpoint_id: u16, matter: &'a Matter<'a>) -> RootEndpointHandler<'a> { +pub fn handler<'a, T>(endpoint_id: u16, matter: &'a T) -> RootEndpointHandler<'a> +where + T: Borrow> + + Borrow + + Borrow> + + Borrow> + + Borrow> + + Borrow> + + Borrow> + + Borrow + + Borrow + + 'a, +{ wrap( endpoint_id, - matter.dev_det(), - matter.dev_att(), + matter.borrow(), + matter.borrow(), matter.borrow(), matter.borrow(), matter.borrow(), @@ -79,7 +90,7 @@ pub fn wrap<'a>( fabric: &'a RefCell, acl: &'a RefCell, failsafe: &'a RefCell, - mdns_mgr: &'a RefCell>, + mdns_mgr: &'a MdnsMgr<'a>, epoch: Epoch, rand: Rand, ) -> RootEndpointHandler<'a> { diff --git a/matter/src/data_model/sdm/admin_commissioning.rs b/matter/src/data_model/sdm/admin_commissioning.rs index b63aa2e1..93643115 100644 --- a/matter/src/data_model/sdm/admin_commissioning.rs +++ b/matter/src/data_model/sdm/admin_commissioning.rs @@ -102,15 +102,11 @@ pub struct OpenCommWindowReq<'a> { pub struct AdminCommCluster<'a> { data_ver: Dataver, pase_mgr: &'a RefCell, - mdns_mgr: &'a RefCell>, + mdns_mgr: &'a MdnsMgr<'a>, } impl<'a> AdminCommCluster<'a> { - pub fn new( - pase_mgr: &'a RefCell, - mdns_mgr: &'a RefCell>, - rand: Rand, - ) -> Self { + pub fn new(pase_mgr: &'a RefCell, mdns_mgr: &'a MdnsMgr<'a>, rand: Rand) -> Self { Self { data_ver: Dataver::new(rand), pase_mgr, @@ -159,7 +155,7 @@ impl<'a> AdminCommCluster<'a> { self.pase_mgr.borrow_mut().enable_pase_session( verifier, req.discriminator, - &mut self.mdns_mgr.borrow_mut(), + self.mdns_mgr, )?; Ok(()) diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index b8dda3cf..f347b13c 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -222,7 +222,7 @@ pub struct NocCluster<'a> { fabric_mgr: &'a RefCell, acl_mgr: &'a RefCell, failsafe: &'a RefCell, - mdns_mgr: &'a RefCell>, + mdns_mgr: &'a MdnsMgr<'a>, } impl<'a> NocCluster<'a> { @@ -231,7 +231,7 @@ impl<'a> NocCluster<'a> { fabric_mgr: &'a RefCell, acl_mgr: &'a RefCell, failsafe: &'a RefCell, - mdns_mgr: &'a RefCell>, + mdns_mgr: &'a MdnsMgr<'a>, epoch: Epoch, rand: Rand, ) -> Self { @@ -383,7 +383,7 @@ impl<'a> NocCluster<'a> { let fab_idx = self .fabric_mgr .borrow_mut() - .add(fabric, &mut self.mdns_mgr.borrow_mut()) + .add(fabric, self.mdns_mgr) .map_err(|_| NocStatus::TableFull)?; self.add_acl(fab_idx, r.case_admin_subject)?; @@ -455,7 +455,7 @@ impl<'a> NocCluster<'a> { if self .fabric_mgr .borrow_mut() - .remove(req.fab_idx, &mut self.mdns_mgr.borrow_mut()) + .remove(req.fab_idx, self.mdns_mgr) .is_ok() { let _ = self.acl_mgr.borrow_mut().delete_for_fabric(req.fab_idx); diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index f6f64ef7..04369ca0 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -200,7 +200,7 @@ impl FabricMgr { } } - pub fn load(&mut self, data: &[u8], mdns_mgr: &mut MdnsMgr) -> Result<(), Error> { + pub fn load(&mut self, data: &[u8], mdns_mgr: &MdnsMgr) -> Result<(), Error> { for fabric in self.fabrics.iter().flatten() { mdns_mgr.unpublish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; } @@ -241,7 +241,7 @@ impl FabricMgr { self.changed } - pub fn add(&mut self, f: Fabric, mdns_mgr: &mut MdnsMgr) -> Result { + pub fn add(&mut self, f: Fabric, mdns_mgr: &MdnsMgr) -> Result { let slot = self.fabrics.iter().position(|x| x.is_none()); if slot.is_some() || self.fabrics.len() < MAX_SUPPORTED_FABRICS { @@ -265,7 +265,7 @@ impl FabricMgr { } } - pub fn remove(&mut self, fab_idx: u8, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> { + pub fn remove(&mut self, fab_idx: u8, mdns_mgr: &MdnsMgr) -> Result<(), Error> { if fab_idx > 0 && fab_idx as usize <= self.fabrics.len() { if let Some(f) = self.fabrics[(fab_idx - 1) as usize].take() { mdns_mgr.unpublish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index 0686061c..cc763a84 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -204,7 +204,7 @@ impl<'a, 'b> Transaction<'a, 'b> { } /* Interaction Model ID as per the Matter Spec */ -const PROTO_ID_INTERACTION_MODEL: usize = 0x01; +pub const PROTO_ID_INTERACTION_MODEL: u16 = 0x01; const MAX_RESUME_PATHS: usize = 32; const MAX_RESUME_DATAVER_FILTERS: usize = 32; @@ -228,8 +228,7 @@ pub enum Interaction<'a> { impl<'a> Interaction<'a> { fn new(rx: &'a Packet, transaction: &mut Transaction) -> Result, Error> { - let opcode: OpCode = - num::FromPrimitive::from_u8(rx.get_proto_opcode()).ok_or(ErrorCode::Invalid)?; + let opcode: OpCode = rx.get_proto_opcode()?; let rx_data = rx.as_slice(); @@ -303,7 +302,7 @@ impl<'a> Interaction<'a> { } fn create_status_response(tx: &mut Packet, status: IMStatusCode) -> Result<(), Error> { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::StatusResponse as u8); let mut tw = TLVWriter::new(tx.get_writebuf()?); @@ -332,7 +331,7 @@ impl<'a> ReadReq<'a> { } fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::ReportData as u8); let mut tw = Self::reserve_long_read_space(tx)?; @@ -410,7 +409,7 @@ impl<'a> WriteReq<'a> { Ok(false) } else { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::WriteResponse as u8); let mut tw = TLVWriter::new(tx.get_writebuf()?); @@ -459,7 +458,7 @@ impl<'a> InvReq<'a> { Ok(false) } else { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::InvokeResponse as u8); let mut tw = TLVWriter::new(tx.get_writebuf()?); @@ -503,7 +502,7 @@ impl<'a> InvReq<'a> { impl TimedReq { pub fn process(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result<(), Error> { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::StatusResponse as u8); let mut tw = TLVWriter::new(tx.get_writebuf()?); @@ -547,7 +546,7 @@ impl<'a> SubscribeReq<'a> { } fn initiate(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::ReportData as u8); let mut tw = ReadReq::reserve_long_read_space(tx)?; @@ -615,7 +614,7 @@ pub struct ResumeReadReq { impl ResumeReadReq { fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::ReportData as u8); let mut tw = ReadReq::reserve_long_read_space(tx)?; @@ -679,7 +678,7 @@ pub struct ResumeSubscribeReq { impl ResumeSubscribeReq { fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); if self.resume_path.is_some() { tx.set_proto_opcode(OpCode::ReportData as u8); diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index eaba7ee9..3bc4a1db 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -21,7 +21,7 @@ use crate::error::Error; pub trait Mdns { fn add( - &mut self, + &self, name: &str, service: &str, protocol: &str, @@ -30,8 +30,7 @@ pub trait Mdns { txt_kvs: &[(&str, &str)], ) -> Result<(), Error>; - fn remove(&mut self, name: &str, service: &str, protocol: &str, port: u16) - -> Result<(), Error>; + fn remove(&self, name: &str, service: &str, protocol: &str, port: u16) -> Result<(), Error>; } impl Mdns for &mut T @@ -39,7 +38,7 @@ where T: Mdns, { fn add( - &mut self, + &self, name: &str, service: &str, protocol: &str, @@ -50,13 +49,7 @@ where (**self).add(name, service, protocol, port, service_subtypes, txt_kvs) } - fn remove( - &mut self, - name: &str, - service: &str, - protocol: &str, - port: u16, - ) -> Result<(), Error> { + fn remove(&self, name: &str, service: &str, protocol: &str, port: u16) -> Result<(), Error> { (**self).remove(name, service, protocol, port) } } @@ -65,7 +58,7 @@ pub struct DummyMdns; impl Mdns for DummyMdns { fn add( - &mut self, + &self, _name: &str, _service: &str, _protocol: &str, @@ -77,7 +70,7 @@ impl Mdns for DummyMdns { } fn remove( - &mut self, + &self, _name: &str, _service: &str, _protocol: &str, @@ -101,11 +94,11 @@ pub struct MdnsMgr<'a> { /// Product ID pid: u16, /// Device name - device_name: heapless::String<32>, + device_name: &'a str, /// Matter port matter_port: u16, /// mDns service - mdns: &'a mut dyn Mdns, + pub(crate) mdns: &'a dyn Mdns, } impl<'a> MdnsMgr<'a> { @@ -113,14 +106,14 @@ impl<'a> MdnsMgr<'a> { pub fn new( vid: u16, pid: u16, - device_name: &str, + device_name: &'a str, matter_port: u16, - mdns: &'a mut dyn Mdns, + mdns: &'a dyn Mdns, ) -> Self { Self { vid, pid, - device_name: device_name.chars().take(32).collect(), + device_name, matter_port, mdns, } @@ -130,7 +123,7 @@ impl<'a> MdnsMgr<'a> { /// name - is the service name (comma separated subtypes may follow) /// mode - the current service mode #[allow(clippy::needless_pass_by_value)] - pub fn publish_service(&mut self, name: &str, mode: ServiceMode) -> Result<(), Error> { + pub fn publish_service(&self, name: &str, mode: ServiceMode) -> Result<(), Error> { match mode { ServiceMode::Commissioned => { self.mdns @@ -143,7 +136,7 @@ impl<'a> MdnsMgr<'a> { let txt_kvs = [ ("D", discriminator_str.as_str()), ("CM", "1"), - ("DN", self.device_name.as_str()), + ("DN", self.device_name), ("VP", &vp), ("SII", "5000"), /* Sleepy Idle Interval */ ("SAI", "300"), /* Sleepy Active Interval */ @@ -166,7 +159,7 @@ impl<'a> MdnsMgr<'a> { } } - pub fn unpublish_service(&mut self, name: &str, mode: ServiceMode) -> Result<(), Error> { + pub fn unpublish_service(&self, name: &str, mode: ServiceMode) -> Result<(), Error> { match mode { ServiceMode::Commissioned => { self.mdns.remove(name, "_matter", "_tcp", self.matter_port) @@ -216,6 +209,7 @@ impl<'a> MdnsMgr<'a> { pub mod builtin { use core::cell::RefCell; use core::fmt::Write; + use core::mem::MaybeUninit; use core::pin::pin; use core::str::FromStr; @@ -224,15 +218,16 @@ pub mod builtin { use domain::base::octets::{Octets256, Octets64, OctetsBuilder}; use domain::base::{Dname, MessageBuilder, Record, ShortBuf}; use domain::rdata::{Aaaa, Ptr, Srv, Txt, A}; - use embassy_futures::select::select; - use embassy_sync::blocking_mutex::raw::NoopRawMutex; + use embassy_futures::select::{select, select3}; use embassy_time::{Duration, Timer}; use log::info; use crate::error::{Error, ErrorCode}; - use crate::transport::network::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + use crate::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + use crate::transport::packet::{MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE}; + use crate::transport::pipe::{Chunk, Pipe}; use crate::transport::udp::UdpListener; - use crate::utils::select::EitherUnwrap; + use crate::utils::select::{EitherUnwrap, Notification}; const IP_BROADCAST_ADDRS: [(IpAddr, u16); 2] = [ (IpAddr::V4(Ipv4Addr::new(224, 0, 0, 251)), 5353), @@ -244,6 +239,9 @@ pub mod builtin { const IP_BIND_ADDR: (IpAddr, u16) = (IpAddr::V6(Ipv6Addr::UNSPECIFIED), 5353); + pub type MdnsTxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; + pub type MdnsRxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>; + #[allow(clippy::too_many_arguments)] pub fn create_record( id: u16, @@ -382,8 +380,6 @@ pub mod builtin { Ok(target.len()) } - pub type Notification = embassy_sync::signal::Signal; - #[derive(Debug, Clone)] struct MdnsEntry { key: heapless::String<64>, @@ -407,7 +403,6 @@ pub mod builtin { ipv6: Option<[u8; 16]>, entries: RefCell>, notification: Notification, - udp: RefCell>, } impl<'a> Mdns<'a> { @@ -420,7 +415,6 @@ pub mod builtin { ipv6, entries: RefCell::new(heapless::Vec::new()), notification: Notification::new(), - udp: RefCell::new(None), } } @@ -428,19 +422,6 @@ pub mod builtin { (MdnsApi(&*self), MdnsRunner(&*self)) } - async fn bind(&self) -> Result<(), Error> { - if self.udp.borrow().is_none() { - *self.udp.borrow_mut() = - Some(UdpListener::new(SocketAddr::new(IP_BIND_ADDR.0, IP_BIND_ADDR.1)).await?); - } - - Ok(()) - } - - pub fn close(&mut self) { - *self.udp.borrow_mut() = None; - } - fn key( &self, name: &str, @@ -546,15 +527,72 @@ pub mod builtin { pub struct MdnsRunner<'a, 'b>(&'a Mdns<'b>); impl<'a, 'b> MdnsRunner<'a, 'b> { - pub async fn run(&mut self) -> Result<(), Error> { - let mut broadcast = pin!(self.broadcast()); - let mut respond = pin!(self.respond()); + pub async fn run_udp( + &mut self, + tx_buf: &mut MdnsTxBuf, + rx_buf: &mut MdnsRxBuf, + ) -> Result<(), Error> { + let udp = UdpListener::new(SocketAddr::new(IP_BIND_ADDR.0, IP_BIND_ADDR.1)).await?; + + let tx_pipe = Pipe::new(unsafe { tx_buf.assume_init_mut() }); + let rx_pipe = Pipe::new(unsafe { rx_buf.assume_init_mut() }); + + let tx_pipe = &tx_pipe; + let rx_pipe = &rx_pipe; + let udp = &udp; + + let mut tx = pin!(async move { + loop { + { + let mut data = tx_pipe.data.lock().await; + + if let Some(chunk) = data.chunk { + udp.send(chunk.addr.unwrap_udp(), &data.buf[chunk.start..chunk.end]) + .await?; + data.chunk = None; + tx_pipe.data_consumed_notification.signal(()); + } + } + + tx_pipe.data_supplied_notification.wait().await; + } + }); + + let mut rx = pin!(async move { + loop { + { + let mut data = rx_pipe.data.lock().await; + + if data.chunk.is_none() { + let (len, addr) = udp.recv(data.buf).await?; + + data.chunk = Some(Chunk { + start: 0, + end: len, + addr: Address::Udp(addr), + }); + rx_pipe.data_supplied_notification.signal(()); + } + } + + rx_pipe.data_consumed_notification.wait().await; + } + }); + + let mut run = pin!(async move { self.run(tx_pipe, rx_pipe).await }); + + select3(&mut tx, &mut rx, &mut run).await.unwrap() + } + + pub async fn run(&mut self, tx_pipe: &Pipe<'_>, rx_pipe: &Pipe<'_>) -> Result<(), Error> { + let mut broadcast = pin!(self.broadcast(tx_pipe)); + let mut respond = pin!(self.respond(rx_pipe, tx_pipe)); select(&mut broadcast, &mut respond).await.unwrap() } #[allow(clippy::await_holding_refcell_ref)] - async fn broadcast(&self) -> Result<(), Error> { + async fn broadcast(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> { loop { select( self.0.notification.wait(), @@ -564,51 +602,74 @@ pub mod builtin { let mut index = 0; - loop { - let entry = self.0.entries.borrow().get(index).cloned(); - - if let Some(entry) = entry { - info!("Broadasting mDNS entry {}", &entry.key); - - self.0.bind().await?; - - let udp = self.0.udp.borrow(); - let udp = udp.as_ref().unwrap(); - - for (addr, port) in IP_BROADCAST_ADDRS { - udp.send(SocketAddr::new(addr, port), &entry.record).await?; + 'outer: loop { + for (addr, port) in IP_BROADCAST_ADDRS { + loop { + { + let mut data = tx_pipe.data.lock().await; + + if data.chunk.is_none() { + let entries = self.0.entries.borrow(); + let entry = entries.get(index); + + if let Some(entry) = entry { + info!( + "Broadasting mDNS entry {} on {}:{}", + &entry.key, addr, port + ); + + let len = entry.record.len(); + data.buf[..len].copy_from_slice(&entry.record); + drop(entries); + + data.chunk = Some(Chunk { + start: 0, + end: len, + addr: Address::Udp(SocketAddr::new(addr, port)), + }); + + tx_pipe.data_supplied_notification.signal(()); + } else { + break 'outer; + } + + break; + } + } + + tx_pipe.data_consumed_notification.wait().await; } - - index += 1; - } else { - break; } + + index += 1; } } } #[allow(clippy::await_holding_refcell_ref)] - async fn respond(&self) -> Result<(), Error> { + async fn respond(&self, rx_pipe: &Pipe<'_>, _tx_pipe: &Pipe<'_>) -> Result<(), Error> { loop { - let mut buf = [0; 1580]; + { + let mut data = rx_pipe.data.lock().await; - let udp = self.0.udp.borrow(); - let udp = udp.as_ref().unwrap(); + if let Some(_chunk) = data.chunk { + // TODO: Process the incoming packed and only answer what we are being queried about - let (_len, _addr) = udp.recv(&mut buf).await?; + data.chunk = None; + rx_pipe.data_consumed_notification.signal(()); - info!("Received UDP packet"); - - // TODO: Process the incoming packed and only answer what we are being queried about + self.0.notification.signal(()); + } + } - self.0.notification.signal(()); + rx_pipe.data_supplied_notification.wait().await; } } } impl<'a, 'b> super::Mdns for MdnsApi<'a, 'b> { fn add( - &mut self, + &self, name: &str, service: &str, protocol: &str, @@ -628,7 +689,7 @@ pub mod builtin { } fn remove( - &mut self, + &self, name: &str, service: &str, protocol: &str, @@ -641,6 +702,7 @@ pub mod builtin { #[cfg(all(feature = "std", feature = "astro-dnssd"))] pub mod astro { + use core::cell::RefCell; use std::collections::HashMap; use super::Mdns; @@ -657,18 +719,18 @@ pub mod astro { } pub struct AstroMdns { - services: HashMap, + services: RefCell>, } impl AstroMdns { pub fn new() -> Result { Ok(Self { - services: HashMap::new(), + services: RefCell::new(HashMap::new()), }) } pub fn add( - &mut self, + &self, name: &str, service: &str, protocol: &str, @@ -698,7 +760,7 @@ pub mod astro { let svc = builder.register().map_err(|_| ErrorCode::MdnsError)?; - self.services.insert( + self.services.borrow_mut().insert( ServiceId { name: name.into(), service: service.into(), @@ -712,7 +774,7 @@ pub mod astro { } pub fn remove( - &mut self, + &self, name: &str, service: &str, protocol: &str, @@ -725,7 +787,7 @@ pub mod astro { port, }; - if self.services.remove(&id).is_some() { + if self.services.borrow_mut().remove(&id).is_some() { info!( "Deregistering mDNS service {}/{}.{}/{}", name, service, protocol, port @@ -738,7 +800,7 @@ pub mod astro { impl Mdns for AstroMdns { fn add( - &mut self, + &self, name: &str, service: &str, protocol: &str, @@ -758,7 +820,7 @@ pub mod astro { } fn remove( - &mut self, + &self, name: &str, service: &str, protocol: &str, diff --git a/matter/src/secure_channel/common.rs b/matter/src/secure_channel/common.rs index c007ee5f..80fb7b51 100644 --- a/matter/src/secure_channel/common.rs +++ b/matter/src/secure_channel/common.rs @@ -24,7 +24,7 @@ use super::status_report::{create_status_report, GeneralCode}; /* Interaction Model ID as per the Matter Spec */ pub const PROTO_ID_SECURE_CHANNEL: u16 = 0x00; -#[derive(FromPrimitive, Debug)] +#[derive(FromPrimitive, Debug, Copy, Clone, Eq, PartialEq)] pub enum OpCode { MsgCounterSyncReq = 0x00, MsgCounterSyncResp = 0x01, @@ -56,8 +56,6 @@ pub fn create_sc_status_report( status_code: SCStatusCodes, proto_data: Option<&[u8]>, ) -> Result<(), Error> { - proto_tx.reset(); - let general_code = match status_code { SCStatusCodes::SessionEstablishmentSuccess => GeneralCode::Success, SCStatusCodes::CloseSession => { @@ -71,6 +69,7 @@ pub fn create_sc_status_report( | SCStatusCodes::NoSharedTrustRoots | SCStatusCodes::SessionNotFound => GeneralCode::Failure, }; + create_status_report( proto_tx, general_code, diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index 21196917..523278e6 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -27,7 +27,6 @@ use crate::{ utils::{epoch::Epoch, rand::Rand}, }; use log::{error, info}; -use num; use super::{case::Case, pake::PaseMgr}; @@ -37,7 +36,7 @@ use super::{case::Case, pake::PaseMgr}; pub struct SecureChannel<'a> { case: Case<'a>, pase: &'a RefCell, - mdns: &'a RefCell>, + mdns: &'a MdnsMgr<'a>, } impl<'a> SecureChannel<'a> { @@ -45,7 +44,7 @@ impl<'a> SecureChannel<'a> { pub fn new< T: Borrow> + Borrow> - + Borrow>> + + Borrow> + Borrow + Borrow, >( @@ -63,7 +62,7 @@ impl<'a> SecureChannel<'a> { pub fn wrap( pase: &'a RefCell, fabric: &'a RefCell, - mdns: &'a RefCell>, + mdns: &'a MdnsMgr<'a>, rand: Rand, ) -> Self { Self { @@ -74,8 +73,8 @@ impl<'a> SecureChannel<'a> { } pub fn handle(&mut self, ctx: &mut ProtoCtx) -> Result<(bool, Option), Error> { - let proto_opcode: OpCode = - num::FromPrimitive::from_u8(ctx.rx.get_proto_opcode()).ok_or(ErrorCode::Invalid)?; + let proto_opcode: OpCode = ctx.rx.get_proto_opcode()?; + ctx.tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); info!("Received Opcode: {:?}", proto_opcode); info!("Received Data:"); @@ -92,10 +91,7 @@ impl<'a> SecureChannel<'a> { .borrow_mut() .pasepake1_handler(ctx) .map(|reply| (reply, None)), - OpCode::PASEPake3 => self - .pase - .borrow_mut() - .pasepake3_handler(ctx, &mut self.mdns.borrow_mut()), + OpCode::PASEPake3 => self.pase.borrow_mut().pasepake3_handler(ctx, self.mdns), OpCode::CASESigma1 => self.case.casesigma1_handler(ctx).map(|reply| (reply, None)), OpCode::CASESigma3 => self.case.casesigma3_handler(ctx), _ => { diff --git a/matter/src/secure_channel/pake.rs b/matter/src/secure_channel/pake.rs index b5a29a2d..60920d03 100644 --- a/matter/src/secure_channel/pake.rs +++ b/matter/src/secure_channel/pake.rs @@ -67,7 +67,7 @@ impl PaseMgr { &mut self, verifier: VerifierData, discriminator: u16, - mdns: &mut MdnsMgr, + mdns: &MdnsMgr, ) -> Result<(), Error> { let mut buf = [0; 8]; (self.rand)(&mut buf); @@ -89,7 +89,7 @@ impl PaseMgr { Ok(()) } - pub fn disable_pase_session(&mut self, mdns: &mut MdnsMgr) -> Result<(), Error> { + pub fn disable_pase_session(&mut self, mdns: &MdnsMgr) -> Result<(), Error> { if let PaseMgrState::Enabled(_, mdns_service_name, discriminator) = &self.state { mdns.unpublish_service( mdns_service_name, @@ -134,7 +134,7 @@ impl PaseMgr { pub fn pasepake3_handler( &mut self, ctx: &mut ProtoCtx, - mdns: &mut MdnsMgr, + mdns: &MdnsMgr, ) -> Result<(bool, Option), Error> { let clone_data = self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?; self.disable_pase_session(mdns)?; diff --git a/matter/src/secure_channel/status_report.rs b/matter/src/secure_channel/status_report.rs index 2f6aed13..e8378746 100644 --- a/matter/src/secure_channel/status_report.rs +++ b/matter/src/secure_channel/status_report.rs @@ -39,6 +39,7 @@ pub enum GeneralCode { PermissionDenied = 15, DataLoss = 16, } + pub fn create_status_report( proto_tx: &mut Packet, general_code: GeneralCode, diff --git a/matter/src/transport/core.rs b/matter/src/transport/core.rs index 1d02bc0c..1b169eec 100644 --- a/matter/src/transport/core.rs +++ b/matter/src/transport/core.rs @@ -41,15 +41,15 @@ pub enum RecvAction<'r, 'p> { Interact(ProtoCtx<'r, 'p>), } -pub struct RecvCompletion<'r, 'a, 'p> { +pub struct RecvCompletion<'r, 'a> { transport: &'r mut Transport<'a>, - rx: Packet<'p>, - tx: Packet<'p>, + rx: Packet<'r>, + tx: Packet<'r>, state: RecvState, } -impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { - pub fn next_action(&mut self) -> Result>, Error> { +impl<'r, 'a> RecvCompletion<'r, 'a> { + pub fn next_action(&mut self) -> Result>, Error> { loop { // Polonius will remove the need for unsafe one day let this = unsafe { (self as *mut RecvCompletion).as_mut().unwrap() }; @@ -60,16 +60,13 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { } } - fn maybe_next_action(&mut self) -> Result>>, Error> { + fn maybe_next_action(&mut self) -> Result>>, Error> { self.transport.exch_mgr.purge(); self.tx.reset(); let (state, next) = match core::mem::replace(&mut self.state, RecvState::New) { RecvState::New => { - self.transport - .exch_mgr - .get_sess_mgr() - .decode(&mut self.rx)?; + self.rx.plain_hdr_decode()?; (RecvState::OpenExchange, None) } RecvState::OpenExchange => match self.transport.exch_mgr.recv(&mut self.rx) { @@ -173,16 +170,16 @@ pub enum NotifyAction<'r, 'p> { Notify(ProtoCtx<'r, 'p>), } -pub struct NotifyCompletion<'r, 'a, 'p> { +pub struct NotifyCompletion<'r, 'a> { // TODO _transport: &'r mut Transport<'a>, - _rx: &'r mut Packet<'p>, - _tx: &'r mut Packet<'p>, + _rx: Packet<'r>, + _tx: Packet<'r>, _state: NotifyState, } -impl<'r, 'a, 'p> NotifyCompletion<'r, 'a, 'p> { - pub fn next_action(&mut self) -> Result>, Error> { +impl<'r, 'a> NotifyCompletion<'r, 'a> { + pub fn next_action(&mut self) -> Result>, Error> { loop { // Polonius will remove the need for unsafe one day let this = unsafe { (self as *mut NotifyCompletion).as_mut().unwrap() }; @@ -193,7 +190,7 @@ impl<'r, 'a, 'p> NotifyCompletion<'r, 'a, 'p> { } } - fn maybe_next_action(&mut self) -> Result>>, Error> { + fn maybe_next_action(&mut self) -> Result>>, Error> { Ok(Some(None)) // TODO: Future } } @@ -216,7 +213,7 @@ impl<'a> Transport<'a> { } pub fn matter(&self) -> &Matter<'a> { - &self.matter + self.matter } pub fn start(&mut self, dev_comm: CommissioningData, buf: &mut [u8]) -> Result<(), Error> { @@ -229,12 +226,12 @@ impl<'a> Transport<'a> { Ok(()) } - pub fn recv<'r, 'p>( + pub fn recv<'r>( &'r mut self, addr: Address, - rx_buf: &'p mut [u8], - tx_buf: &'p mut [u8], - ) -> RecvCompletion<'r, 'a, 'p> { + rx_buf: &'r mut [u8], + tx_buf: &'r mut [u8], + ) -> RecvCompletion<'r, 'a> { let mut rx = Packet::new_rx(rx_buf); let tx = Packet::new_tx(tx_buf); diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index 4910dbc0..5dbb1bba 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -223,7 +223,7 @@ impl Exchange { "{} with proto id: {} opcode: {}, tlv:\n", "Sending".blue(), tx.get_proto_id(), - tx.get_proto_opcode(), + tx.get_proto_raw_opcode(), ); //print_tlv_list(tx.as_slice()); diff --git a/matter/src/transport/mod.rs b/matter/src/transport/mod.rs index 18957be5..a219f165 100644 --- a/matter/src/transport/mod.rs +++ b/matter/src/transport/mod.rs @@ -21,6 +21,7 @@ pub mod exchange; pub mod mrp; pub mod network; pub mod packet; +pub mod pipe; pub mod plain_hdr; pub mod proto_ctx; pub mod proto_hdr; diff --git a/matter/src/transport/packet.rs b/matter/src/transport/packet.rs index 72368cb5..5e0cf98a 100644 --- a/matter/src/transport/packet.rs +++ b/matter/src/transport/packet.rs @@ -15,10 +15,14 @@ * limitations under the License. */ -use log::error; +use log::{error, info, trace}; +use owo_colors::OwoColorize; use crate::{ error::{Error, ErrorCode}, + interaction_model::core::PROTO_ID_INTERACTION_MODEL, + secure_channel::common::PROTO_ID_SECURE_CHANNEL, + tlv, utils::{parsebuf::ParseBuf, writebuf::WriteBuf}, }; @@ -29,6 +33,7 @@ use super::{ }; pub const MAX_RX_BUF_SIZE: usize = 1583; +pub const MAX_RX_STATUS_BUF_SIZE: usize = 100; pub const MAX_TX_BUF_SIZE: usize = 1280 - 40/*IPV6 header size*/ - 8/*UDP header size*/; #[derive(Debug, PartialEq, Eq, Copy, Clone)] @@ -160,10 +165,22 @@ impl<'a> Packet<'a> { self.proto.proto_id = proto_id; } - pub fn get_proto_opcode(&self) -> u8 { + pub fn get_proto_opcode(&self) -> Result { + num::FromPrimitive::from_u8(self.proto.proto_opcode).ok_or(ErrorCode::Invalid.into()) + } + + pub fn get_proto_raw_opcode(&self) -> u8 { self.proto.proto_opcode } + pub fn check_proto_opcode(&self, opcode: u8) -> Result<(), Error> { + if self.proto.proto_opcode == opcode { + Ok(()) + } else { + Err(ErrorCode::Invalid.into()) + } + } + pub fn set_proto_opcode(&mut self, proto_opcode: u8) { self.proto.proto_opcode = proto_opcode; } @@ -196,6 +213,52 @@ impl<'a> Packet<'a> { } } + pub fn proto_encode( + &mut self, + peer: Address, + peer_nodeid: Option, + local_nodeid: u64, + plain_text: bool, + enc_key: Option<&[u8]>, + ) -> Result<(), Error> { + self.peer = peer; + + // Generate encrypted header + let mut tmp_buf = [0_u8; proto_hdr::max_proto_hdr_len()]; + let mut write_buf = WriteBuf::new(&mut tmp_buf); + self.proto.encode(&mut write_buf)?; + self.get_writebuf()?.prepend(write_buf.as_slice())?; + + // Generate plain-text header + if plain_text { + if let Some(d) = peer_nodeid { + self.plain.set_dest_u64(d); + } + } + + let mut tmp_buf = [0_u8; plain_hdr::max_plain_hdr_len()]; + let mut write_buf = WriteBuf::new(&mut tmp_buf); + self.plain.encode(&mut write_buf)?; + let plain_hdr_bytes = write_buf.as_slice(); + + trace!("unencrypted packet: {:x?}", self.as_mut_slice()); + let ctr = self.plain.ctr; + if let Some(e) = enc_key { + proto_hdr::encrypt_in_place( + ctr, + local_nodeid, + plain_hdr_bytes, + self.get_writebuf()?, + e, + )?; + } + + self.get_writebuf()?.prepend(plain_hdr_bytes)?; + trace!("Full encrypted packet: {:x?}", self.as_mut_slice()); + + Ok(()) + } + pub fn is_plain_hdr_decoded(&self) -> Result { match &self.data { Direction::Rx(_, state) => match state { @@ -220,4 +283,44 @@ impl<'a> Packet<'a> { _ => Err(ErrorCode::InvalidState.into()), } } + + pub fn log(&self, operation: &str) { + match self.get_proto_id() { + PROTO_ID_SECURE_CHANNEL => { + if let Ok(opcode) = self.get_proto_opcode::() + { + info!("{} SC:{:?}: ", operation.cyan(), opcode); + } else { + info!( + "{} SC:{}??: ", + operation.cyan(), + self.get_proto_raw_opcode() + ); + } + + tlv::print_tlv_list(self.as_slice()); + } + PROTO_ID_INTERACTION_MODEL => { + if let Ok(opcode) = + self.get_proto_opcode::() + { + info!("{} IM:{:?}: ", operation.cyan(), opcode); + } else { + info!( + "{} IM:{}??: ", + operation.cyan(), + self.get_proto_raw_opcode() + ); + } + + tlv::print_tlv_list(self.as_slice()); + } + other => info!( + "{} {}??:{}??: ", + operation.cyan(), + other, + self.get_proto_raw_opcode() + ), + } + } } diff --git a/matter/src/transport/pipe.rs b/matter/src/transport/pipe.rs new file mode 100644 index 00000000..46259cc0 --- /dev/null +++ b/matter/src/transport/pipe.rs @@ -0,0 +1,94 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use embassy_sync::{blocking_mutex::raw::NoopRawMutex, mutex::Mutex}; + +use crate::utils::select::Notification; + +use super::network::Address; + +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +pub struct Chunk { + pub start: usize, + pub end: usize, + pub addr: Address, +} + +pub struct PipeData<'a> { + pub buf: &'a mut [u8], + pub chunk: Option, +} + +pub struct Pipe<'a> { + pub data: Mutex>, + pub data_supplied_notification: Notification, + pub data_consumed_notification: Notification, +} + +impl<'a> Pipe<'a> { + #[inline(always)] + pub fn new(buf: &'a mut [u8]) -> Self { + Self { + data: Mutex::new(PipeData { buf, chunk: None }), + data_supplied_notification: Notification::new(), + data_consumed_notification: Notification::new(), + } + } + + pub async fn recv(&self, buf: &mut [u8]) -> (usize, Address) { + loop { + { + let mut data = self.data.lock().await; + + if let Some(chunk) = data.chunk { + buf[..chunk.end - chunk.start] + .copy_from_slice(&data.buf[chunk.start..chunk.end]); + data.chunk = None; + + self.data_consumed_notification.signal(()); + + return (chunk.end - chunk.start, chunk.addr); + } + } + + self.data_supplied_notification.wait().await + } + } + + pub async fn send(&self, addr: Address, buf: &[u8]) { + loop { + { + let mut data = self.data.lock().await; + + if data.chunk.is_none() { + data.buf[..buf.len()].copy_from_slice(buf); + data.chunk = Some(Chunk { + start: 0, + end: buf.len(), + addr, + }); + + self.data_supplied_notification.signal(()); + + break; + } + } + + self.data_consumed_notification.wait().await + } + } +} diff --git a/matter/src/transport/session.rs b/matter/src/transport/session.rs index 1c2e9365..c421244e 100644 --- a/matter/src/transport/session.rs +++ b/matter/src/transport/session.rs @@ -22,12 +22,8 @@ use core::fmt; use core::ops::{Deref, DerefMut}; use core::time::Duration; -use crate::{ - error::*, - transport::{plain_hdr, proto_hdr}, - utils::writebuf::WriteBuf, -}; -use log::{info, trace}; +use crate::{error::*, transport::plain_hdr}; +use log::info; use super::dedup::RxCtrState; use super::{network::Address, packet::Packet}; @@ -255,44 +251,16 @@ impl Session { Ok(()) } - // TODO: Most of this can now be moved into the 'Packet' module - fn do_send(&mut self, epoch: Epoch, tx: &mut Packet) -> Result<(), Error> { + fn send(&mut self, epoch: Epoch, tx: &mut Packet) -> Result<(), Error> { self.last_use = epoch(); - tx.peer = self.peer_addr; - - // Generate encrypted header - let mut tmp_buf = [0_u8; proto_hdr::max_proto_hdr_len()]; - let mut write_buf = WriteBuf::new(&mut tmp_buf); - tx.proto.encode(&mut write_buf)?; - tx.get_writebuf()?.prepend(write_buf.as_slice())?; - - // Generate plain-text header - if self.mode == SessionMode::PlainText { - if let Some(d) = self.peer_nodeid { - tx.plain.set_dest_u64(d); - } - } - let mut tmp_buf = [0_u8; plain_hdr::max_plain_hdr_len()]; - let mut write_buf = WriteBuf::new(&mut tmp_buf); - tx.plain.encode(&mut write_buf)?; - let plain_hdr_bytes = write_buf.as_slice(); - - trace!("unencrypted packet: {:x?}", tx.as_mut_slice()); - let ctr = tx.plain.ctr; - let enc_key = self.get_enc_key(); - if let Some(e) = enc_key { - proto_hdr::encrypt_in_place( - ctr, - self.local_nodeid, - plain_hdr_bytes, - tx.get_writebuf()?, - e, - )?; - } - tx.get_writebuf()?.prepend(plain_hdr_bytes)?; - trace!("Full encrypted packet: {:x?}", tx.as_mut_slice()); - Ok(()) + tx.proto_encode( + self.peer_addr, + self.peer_nodeid, + self.local_nodeid, + self.mode == SessionMode::PlainText, + self.get_enc_key(), + ) } fn rand_msg_ctr(rand: Rand) -> u32 { @@ -493,32 +461,11 @@ impl SessionMgr { } } - pub fn decode(&mut self, rx: &mut Packet) -> Result<(), Error> { - // let network = self.network.as_ref().ok_or(ErrorCode::NoNetworkInterface)?; - - // let (len, src) = network.recv(rx.as_borrow_slice()).await?; - // rx.get_parsebuf()?.set_len(len); - // rx.peer = src; - - // info!("{} from src: {}", "Received".blue(), src); - // trace!("payload: {:x?}", rx.as_borrow_slice()); - - // Read unencrypted packet header - rx.plain_hdr_decode() - } - pub fn send(&mut self, sess_idx: usize, tx: &mut Packet) -> Result<(), Error> { self.sessions[sess_idx] .as_mut() .ok_or(ErrorCode::NoSession)? - .do_send(self.epoch, tx)?; - - // let network = self.network.as_ref().ok_or(Error::NoNetworkInterface)?; - // let peer = proto_tx.peer; - // network.send(proto_tx.as_borrow_slice(), peer).await?; - // info!("Message Sent to {}", peer); - - Ok(()) + .send(self.epoch, tx) } pub fn get_session_handle(&mut self, sess_idx: usize) -> SessionHandle { diff --git a/matter/src/utils/select.rs b/matter/src/utils/select.rs index 2b5d21e9..a63c10be 100644 --- a/matter/src/utils/select.rs +++ b/matter/src/utils/select.rs @@ -1,4 +1,7 @@ use embassy_futures::select::{Either, Either3, Either4}; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; + +pub type Notification = embassy_sync::signal::Signal; pub trait EitherUnwrap { fn unwrap(self) -> T; diff --git a/matter/src/utils/writebuf.rs b/matter/src/utils/writebuf.rs index 2f24c977..d091dfb5 100644 --- a/matter/src/utils/writebuf.rs +++ b/matter/src/utils/writebuf.rs @@ -38,6 +38,10 @@ impl<'a> WriteBuf<'a> { } } + pub fn get_start(&self) -> usize { + self.start + } + pub fn get_tail(&self) -> usize { self.end } diff --git a/matter/tests/common/echo_cluster.rs b/matter/tests/common/echo_cluster.rs index e5caca70..5e43e18a 100644 --- a/matter/tests/common/echo_cluster.rs +++ b/matter/tests/common/echo_cluster.rs @@ -24,8 +24,8 @@ use matter::{ attribute_enum, command_enum, data_model::objects::{ Access, AttrData, AttrDataEncoder, AttrDataWriter, AttrDetails, AttrType, Attribute, - Cluster, CmdDataEncoder, CmdDataWriter, CmdDetails, Dataver, Handler, Quality, - ATTRIBUTE_LIST, FEATURE_MAP, + Cluster, CmdDataEncoder, CmdDataWriter, CmdDetails, Dataver, Handler, NonBlockingHandler, + Quality, ATTRIBUTE_LIST, FEATURE_MAP, }, error::{Error, ErrorCode}, interaction_model::{ @@ -286,3 +286,5 @@ impl Handler for EchoCluster { EchoCluster::invoke(self, transaction, cmd, data, encoder) } } + +impl NonBlockingHandler for EchoCluster {} diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index 70b2aca0..13da8cd5 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -110,7 +110,7 @@ pub fn matter(mdns: &mut dyn Mdns) -> Matter<'_> { #[cfg(not(feature = "std"))] use matter::utils::epoch::dummy_epoch as epoch; - Matter::new(BASIC_INFO, &DummyDevAtt, mdns, epoch, dummy_rand, 5540) + Matter::new(&BASIC_INFO, &DummyDevAtt, mdns, epoch, dummy_rand, 5540) } /// An Interaction Model Engine to facilitate easy testing @@ -236,7 +236,7 @@ impl<'a> ImEngine<'a> { self.im.handle(&mut ctx).unwrap(); let out_data_len = ctx.tx.as_slice().len(); data_out[..out_data_len].copy_from_slice(ctx.tx.as_slice()); - let response = ctx.tx.get_proto_opcode(); + let response = ctx.tx.get_proto_raw_opcode(); (response, &data_out[..out_data_len]) } } From c0d1b85d9d54498cd70837765944b1b7d3a0263f Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sat, 10 Jun 2023 18:47:21 +0000 Subject: [PATCH 53/72] Default mDns impl --- examples/onoff_light/src/main.rs | 19 ++-- matter/src/mdns.rs | 167 ++++++++++++++++++++----------- 2 files changed, 116 insertions(+), 70 deletions(-) diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index b2e5091c..1bc99441 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -30,7 +30,7 @@ use matter::data_model::root_endpoint; use matter::data_model::system_model::descriptor; use matter::error::Error; use matter::interaction_model::core::InteractionModel; -use matter::mdns::builtin::{Mdns, MdnsRxBuf, MdnsTxBuf}; +use matter::mdns::{DefaultMdns, DefaultMdnsRunner}; use matter::persist; use matter::secure_channel::spake2p::VerifierData; use matter::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; @@ -64,23 +64,21 @@ fn run() -> Result<(), Error> { info!( "Matter memory: mDNS={}, Matter={}, Transport={}", - core::mem::size_of::(), + core::mem::size_of::(), core::mem::size_of::(), core::mem::size_of::(), ); let (ipv4_addr, ipv6_addr) = initialize_network()?; - let mut mdns = matter::mdns::builtin::Mdns::new( + let mdns = DefaultMdns::new( 0, "matter-demo", ipv4_addr.octets(), Some(ipv6_addr.octets()), ); - let (mdns, mut mdns_runner) = mdns.split(); - //let (mut mdns, mdns_runner) = (matter::mdns::astro::AstroMdns::new()?, core::future::pending::pending()); - //let (mut mdns, mdns_runner) = (matter::mdns::DummyMdns {}, core::future::pending::pending()); + let mut mdns_runner = DefaultMdnsRunner::new(&mdns); let dev_att = dev_att::HardCodedDevAtt::new(); @@ -191,14 +189,9 @@ fn run() -> Result<(), Error> { Ok::<_, matter::error::Error>(()) }); - let mut tx_buf = MdnsTxBuf::uninit(); - let mut rx_buf = MdnsRxBuf::uninit(); - let tx_buf = &mut tx_buf; - let rx_buf = &mut rx_buf; - - let mut mdns_fut = pin!(async move { mdns_runner.run_udp(tx_buf, rx_buf).await }); + let mut mdns_fut = pin!(async move { mdns_runner.run_udp().await }); - let mut fut = pin!(async move { select(&mut io_fut, &mut mdns_fut,).await.unwrap() }); + let mut fut = pin!(async move { select(&mut io_fut, &mut mdns_fut).await.unwrap() }); smol::block_on(&mut fut)?; diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index 3bc4a1db..5c831eef 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -54,6 +54,39 @@ where } } +impl Mdns for &T +where + T: Mdns, +{ + fn add( + &self, + name: &str, + service: &str, + protocol: &str, + port: u16, + service_subtypes: &[&str], + txt_kvs: &[(&str, &str)], + ) -> Result<(), Error> { + (**self).add(name, service, protocol, port, service_subtypes, txt_kvs) + } + + fn remove(&self, name: &str, service: &str, protocol: &str, port: u16) -> Result<(), Error> { + (**self).remove(name, service, protocol, port) + } +} + +#[cfg(all(feature = "std", feature = "astro-dnssd"))] +pub type DefaultMdns = astro::Mdns; + +#[cfg(all(feature = "std", feature = "astro-dnssd"))] +pub type DefaultMdnsRunner<'a> = astro::MdnsRunner<'a>; + +#[cfg(not(all(feature = "std", feature = "astro-dnssd")))] +pub type DefaultMdns<'a> = builtin::Mdns<'a>; + +#[cfg(not(all(feature = "std", feature = "astro-dnssd")))] +pub type DefaultMdnsRunner<'a> = builtin::MdnsRunner<'a>; + pub struct DummyMdns; impl Mdns for DummyMdns { @@ -239,8 +272,8 @@ pub mod builtin { const IP_BIND_ADDR: (IpAddr, u16) = (IpAddr::V6(Ipv6Addr::UNSPECIFIED), 5353); - pub type MdnsTxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; - pub type MdnsRxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>; + type MdnsTxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; + type MdnsRxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>; #[allow(clippy::too_many_arguments)] pub fn create_record( @@ -418,28 +451,6 @@ pub mod builtin { } } - pub fn split(&mut self) -> (MdnsApi<'_, 'a>, MdnsRunner<'_, 'a>) { - (MdnsApi(&*self), MdnsRunner(&*self)) - } - - fn key( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - ) -> heapless::String<64> { - let mut key = heapless::String::new(); - - write!(&mut key, "{name}.{service}.{protocol}.{port}").unwrap(); - - key - } - } - - pub struct MdnsApi<'a, 'b>(&'a Mdns<'b>); - - impl<'a, 'b> MdnsApi<'a, 'b> { pub fn add( &self, name: &str, @@ -454,9 +465,9 @@ pub mod builtin { name, service, protocol, service_subtypes, port, txt_kvs ); - let key = self.0.key(name, service, protocol, port); + let key = self.key(name, service, protocol, port); - let mut entries = self.0.entries.borrow_mut(); + let mut entries = self.entries.borrow_mut(); entries.retain(|entry| entry.key != key); entries @@ -471,10 +482,10 @@ pub mod builtin { .unwrap(); match create_record( - self.0.id, - self.0.hostname, - self.0.ip, - self.0.ipv6, + self.id, + self.hostname, + self.ip, + self.ipv6, 60, /*ttl_sec*/ name, service, @@ -491,7 +502,7 @@ pub mod builtin { } } - self.0.notification.signal(()); + self.notification.signal(()); Ok(()) } @@ -508,37 +519,57 @@ pub mod builtin { name, service, protocol, port ); - let key = self.0.key(name, service, protocol, port); + let key = self.key(name, service, protocol, port); - let mut entries = self.0.entries.borrow_mut(); + let mut entries = self.entries.borrow_mut(); let old_len = entries.len(); entries.retain(|entry| entry.key != key); if entries.len() != old_len { - self.0.notification.signal(()); + self.notification.signal(()); } Ok(()) } + + fn key( + &self, + name: &str, + service: &str, + protocol: &str, + port: u16, + ) -> heapless::String<64> { + let mut key = heapless::String::new(); + + write!(&mut key, "{name}.{service}.{protocol}.{port}").unwrap(); + + key + } } - pub struct MdnsRunner<'a, 'b>(&'a Mdns<'b>); + pub struct MdnsRunner<'a>(&'a Mdns<'a>); - impl<'a, 'b> MdnsRunner<'a, 'b> { - pub async fn run_udp( - &mut self, - tx_buf: &mut MdnsTxBuf, - rx_buf: &mut MdnsRxBuf, - ) -> Result<(), Error> { - let udp = UdpListener::new(SocketAddr::new(IP_BIND_ADDR.0, IP_BIND_ADDR.1)).await?; + impl<'a> MdnsRunner<'a> { + pub const fn new(mdns: &'a Mdns<'a>) -> Self { + Self(mdns) + } + + pub async fn run_udp(&mut self) -> Result<(), Error> { + let mut tx_buf = MdnsTxBuf::uninit(); + let mut rx_buf = MdnsRxBuf::uninit(); + + let tx_buf = &mut tx_buf; + let rx_buf = &mut rx_buf; let tx_pipe = Pipe::new(unsafe { tx_buf.assume_init_mut() }); let rx_pipe = Pipe::new(unsafe { rx_buf.assume_init_mut() }); let tx_pipe = &tx_pipe; let rx_pipe = &rx_pipe; + + let udp = UdpListener::new(SocketAddr::new(IP_BIND_ADDR.0, IP_BIND_ADDR.1)).await?; let udp = &udp; let mut tx = pin!(async move { @@ -584,7 +615,7 @@ pub mod builtin { select3(&mut tx, &mut rx, &mut run).await.unwrap() } - pub async fn run(&mut self, tx_pipe: &Pipe<'_>, rx_pipe: &Pipe<'_>) -> Result<(), Error> { + pub async fn run(&self, tx_pipe: &Pipe<'_>, rx_pipe: &Pipe<'_>) -> Result<(), Error> { let mut broadcast = pin!(self.broadcast(tx_pipe)); let mut respond = pin!(self.respond(rx_pipe, tx_pipe)); @@ -667,7 +698,7 @@ pub mod builtin { } } - impl<'a, 'b> super::Mdns for MdnsApi<'a, 'b> { + impl<'a, 'b> super::Mdns for Mdns<'a> { fn add( &self, name: &str, @@ -677,7 +708,7 @@ pub mod builtin { service_subtypes: &[&str], txt_kvs: &[(&str, &str)], ) -> Result<(), Error> { - MdnsApi::add( + Mdns::add( self, name, service, @@ -695,7 +726,7 @@ pub mod builtin { protocol: &str, port: u16, ) -> Result<(), Error> { - MdnsApi::remove(self, name, service, protocol, port) + Mdns::remove(self, name, service, protocol, port) } } } @@ -705,28 +736,34 @@ pub mod astro { use core::cell::RefCell; use std::collections::HashMap; - use super::Mdns; - use crate::error::{Error, ErrorCode}; + use crate::{ + error::{Error, ErrorCode}, + transport::pipe::Pipe, + }; use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService}; use log::info; #[derive(Debug, Clone, Eq, PartialEq, Hash)] - pub struct ServiceId { + struct ServiceId { name: String, service: String, protocol: String, port: u16, } - pub struct AstroMdns { + pub struct Mdns { services: RefCell>, } - impl AstroMdns { - pub fn new() -> Result { - Ok(Self { + impl Mdns { + pub fn new(_id: u16, _hostname: &str, _ip: [u8; 4], _ipv6: Option<[u8; 16]>) -> Self { + Self::native_new() + } + + pub fn native_new() -> Self { + Self { services: RefCell::new(HashMap::new()), - }) + } } pub fn add( @@ -798,7 +835,23 @@ pub mod astro { } } - impl Mdns for AstroMdns { + pub struct MdnsRunner<'a>(&'a Mdns); + + impl<'a> MdnsRunner<'a> { + pub const fn new(mdns: &'a Mdns) -> Self { + Self(mdns) + } + + pub async fn run_udp(&mut self) -> Result<(), Error> { + core::future::pending::>().await + } + + pub async fn run(&self, _tx_pipe: &Pipe<'_>, _rx_pipe: &Pipe<'_>) -> Result<(), Error> { + core::future::pending::>().await + } + } + + impl super::Mdns for Mdns { fn add( &self, name: &str, @@ -808,7 +861,7 @@ pub mod astro { service_subtypes: &[&str], txt_kvs: &[(&str, &str)], ) -> Result<(), Error> { - AstroMdns::add( + Mdns::add( self, name, service, @@ -826,7 +879,7 @@ pub mod astro { protocol: &str, port: u16, ) -> Result<(), Error> { - AstroMdns::remove(self, name, service, protocol, port) + Mdns::remove(self, name, service, protocol, port) } } } From b882aad1ffb03dfa5f5961ce2a37e5ff7d9ff839 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sat, 10 Jun 2023 18:51:34 +0000 Subject: [PATCH 54/72] Clippy --- matter/src/mdns.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index 5c831eef..897ec32e 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -698,7 +698,7 @@ pub mod builtin { } } - impl<'a, 'b> super::Mdns for Mdns<'a> { + impl<'a> super::Mdns for Mdns<'a> { fn add( &self, name: &str, From 488ef5b9f0decf5104247aec15af633d9bd9abdc Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 12 Jun 2023 09:47:20 +0000 Subject: [PATCH 55/72] Proper mDNS responder --- examples/onoff_light/src/main.rs | 22 +- matter/Cargo.toml | 2 +- matter/src/core.rs | 18 +- matter/src/data_model/root_endpoint.rs | 10 +- .../src/data_model/sdm/admin_commissioning.rs | 16 +- matter/src/data_model/sdm/noc.rs | 12 +- matter/src/fabric.rs | 16 +- matter/src/mdns.rs | 826 ++---------------- matter/src/mdns/astro.rs | 106 +++ matter/src/mdns/builtin.rs | 317 +++++++ matter/src/mdns/proto.rs | 508 +++++++++++ matter/src/secure_channel/core.rs | 8 +- matter/src/secure_channel/pake.rs | 24 +- matter/src/transport/udp.rs | 19 +- 14 files changed, 1064 insertions(+), 840 deletions(-) create mode 100644 matter/src/mdns/astro.rs create mode 100644 matter/src/mdns/builtin.rs create mode 100644 matter/src/mdns/proto.rs diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 1bc99441..e26f20f1 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -69,6 +69,16 @@ fn run() -> Result<(), Error> { core::mem::size_of::(), ); + let dev_det = BasicInfoConfig { + vid: 0xFFF1, + pid: 0x8000, + hw_ver: 2, + sw_ver: 1, + sw_ver_str: "1", + serial_no: "aabbccdd", + device_name: "OnOff Light", + }; + let (ipv4_addr, ipv6_addr) = initialize_network()?; let mdns = DefaultMdns::new( @@ -76,6 +86,8 @@ fn run() -> Result<(), Error> { "matter-demo", ipv4_addr.octets(), Some(ipv6_addr.octets()), + &dev_det, + matter::MATTER_PORT, ); let mut mdns_runner = DefaultMdnsRunner::new(&mdns); @@ -84,15 +96,7 @@ fn run() -> Result<(), Error> { let matter = Matter::new_default( // vid/pid should match those in the DAC - &BasicInfoConfig { - vid: 0xFFF1, - pid: 0x8000, - hw_ver: 2, - sw_ver: 1, - sw_ver_str: "1", - serial_no: "aabbccdd", - device_name: "OnOff Light", - }, + &dev_det, &dev_att, &mdns, matter::MATTER_PORT, diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 410b30ce..2f859b7c 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -46,7 +46,7 @@ embassy-futures = "0.1" embassy-time = { version = "0.1.1", features = ["generic-queue-8"] } embassy-sync = "0.2" critical-section = "1.1.1" -domain = { version = "0.7.2", default_features = false } +domain = { version = "0.7.2", default_features = false, features = ["heapless"] } # STD-only dependencies rand = { version = "0.8.5", optional = true } diff --git a/matter/src/core.rs b/matter/src/core.rs index c6c5dc10..35c86771 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -25,7 +25,7 @@ use crate::{ }, error::*, fabric::FabricMgr, - mdns::{Mdns, MdnsMgr}, + mdns::Mdns, pairing::{print_pairing_code_and_qr, DiscoveryCapabilities}, secure_channel::{pake::PaseMgr, spake2p::VerifierData}, utils::{epoch::Epoch, rand::Rand}, @@ -48,7 +48,7 @@ pub struct Matter<'a> { pub acl_mgr: RefCell, pub pase_mgr: RefCell, pub failsafe: RefCell, - pub mdns_mgr: MdnsMgr<'a>, + pub mdns: &'a dyn Mdns, pub epoch: Epoch, pub rand: Rand, pub dev_det: &'a BasicInfoConfig<'a>, @@ -91,7 +91,7 @@ impl<'a> Matter<'a> { acl_mgr: RefCell::new(AclMgr::new()), pase_mgr: RefCell::new(PaseMgr::new(epoch, rand)), failsafe: RefCell::new(FailSafe::new()), - mdns_mgr: MdnsMgr::new(dev_det.vid, dev_det.pid, dev_det.device_name, port, mdns), + mdns, epoch, rand, dev_det, @@ -113,7 +113,7 @@ impl<'a> Matter<'a> { } pub fn load_fabrics(&self, data: &[u8]) -> Result<(), Error> { - self.fabric_mgr.borrow_mut().load(data, &self.mdns_mgr) + self.fabric_mgr.borrow_mut().load(data, self.mdns) } pub fn load_acls(&self, data: &[u8]) -> Result<(), Error> { @@ -149,7 +149,7 @@ impl<'a> Matter<'a> { self.pase_mgr.borrow_mut().enable_pase_session( dev_comm.verifier, dev_comm.discriminator, - &self.mdns_mgr, + self.mdns, )?; Ok(true) @@ -183,12 +183,6 @@ impl<'a> Borrow> for Matter<'a> { } } -impl<'a> Borrow> for Matter<'a> { - fn borrow(&self) -> &MdnsMgr<'a> { - &self.mdns_mgr - } -} - impl<'a> Borrow> for Matter<'a> { fn borrow(&self) -> &BasicInfoConfig<'a> { self.dev_det @@ -203,7 +197,7 @@ impl<'a> Borrow for Matter<'a> { impl<'a> Borrow for Matter<'a> { fn borrow(&self) -> &(dyn Mdns + 'a) { - self.mdns_mgr.mdns + self.mdns } } diff --git a/matter/src/data_model/root_endpoint.rs b/matter/src/data_model/root_endpoint.rs index 1bc22fe2..78b8cfb9 100644 --- a/matter/src/data_model/root_endpoint.rs +++ b/matter/src/data_model/root_endpoint.rs @@ -4,7 +4,7 @@ use crate::{ acl::AclMgr, fabric::FabricMgr, handler_chain_type, - mdns::MdnsMgr, + mdns::Mdns, secure_channel::pake::PaseMgr, utils::{epoch::Epoch, rand::Rand}, }; @@ -62,7 +62,7 @@ where + Borrow> + Borrow> + Borrow> - + Borrow> + + Borrow + Borrow + Borrow + 'a, @@ -90,7 +90,7 @@ pub fn wrap<'a>( fabric: &'a RefCell, acl: &'a RefCell, failsafe: &'a RefCell, - mdns_mgr: &'a MdnsMgr<'a>, + mdns: &'a dyn Mdns, epoch: Epoch, rand: Rand, ) -> RootEndpointHandler<'a> { @@ -103,12 +103,12 @@ pub fn wrap<'a>( .chain( endpoint_id, noc::ID, - NocCluster::new(dev_att, fabric, acl, failsafe, mdns_mgr, epoch, rand), + NocCluster::new(dev_att, fabric, acl, failsafe, mdns, epoch, rand), ) .chain( endpoint_id, admin_commissioning::ID, - AdminCommCluster::new(pase, mdns_mgr, rand), + AdminCommCluster::new(pase, mdns, rand), ) .chain(endpoint_id, nw_commissioning::ID, NwCommCluster::new(rand)) .chain( diff --git a/matter/src/data_model/sdm/admin_commissioning.rs b/matter/src/data_model/sdm/admin_commissioning.rs index 93643115..15c803f4 100644 --- a/matter/src/data_model/sdm/admin_commissioning.rs +++ b/matter/src/data_model/sdm/admin_commissioning.rs @@ -20,7 +20,7 @@ use core::convert::TryInto; use crate::data_model::objects::*; use crate::interaction_model::core::Transaction; -use crate::mdns::MdnsMgr; +use crate::mdns::Mdns; use crate::secure_channel::pake::PaseMgr; use crate::secure_channel::spake2p::VerifierData; use crate::tlv::{FromTLV, Nullable, OctetStr, TLVElement}; @@ -102,15 +102,15 @@ pub struct OpenCommWindowReq<'a> { pub struct AdminCommCluster<'a> { data_ver: Dataver, pase_mgr: &'a RefCell, - mdns_mgr: &'a MdnsMgr<'a>, + mdns: &'a dyn Mdns, } impl<'a> AdminCommCluster<'a> { - pub fn new(pase_mgr: &'a RefCell, mdns_mgr: &'a MdnsMgr<'a>, rand: Rand) -> Self { + pub fn new(pase_mgr: &'a RefCell, mdns: &'a dyn Mdns, rand: Rand) -> Self { Self { data_ver: Dataver::new(rand), pase_mgr, - mdns_mgr, + mdns, } } @@ -152,11 +152,9 @@ impl<'a> AdminCommCluster<'a> { cmd_enter!("Open Commissioning Window"); let req = OpenCommWindowReq::from_tlv(data)?; let verifier = VerifierData::new(req.verifier.0, req.iterations, req.salt.0); - self.pase_mgr.borrow_mut().enable_pase_session( - verifier, - req.discriminator, - self.mdns_mgr, - )?; + self.pase_mgr + .borrow_mut() + .enable_pase_session(verifier, req.discriminator, self.mdns)?; Ok(()) } diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index f347b13c..7fb1e37b 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -25,7 +25,7 @@ use crate::data_model::objects::*; use crate::data_model::sdm::dev_att; use crate::fabric::{Fabric, FabricMgr, MAX_SUPPORTED_FABRICS}; use crate::interaction_model::core::Transaction; -use crate::mdns::MdnsMgr; +use crate::mdns::Mdns; use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; use crate::transport::session::SessionMode; use crate::utils::epoch::Epoch; @@ -222,7 +222,7 @@ pub struct NocCluster<'a> { fabric_mgr: &'a RefCell, acl_mgr: &'a RefCell, failsafe: &'a RefCell, - mdns_mgr: &'a MdnsMgr<'a>, + mdns: &'a dyn Mdns, } impl<'a> NocCluster<'a> { @@ -231,7 +231,7 @@ impl<'a> NocCluster<'a> { fabric_mgr: &'a RefCell, acl_mgr: &'a RefCell, failsafe: &'a RefCell, - mdns_mgr: &'a MdnsMgr<'a>, + mdns: &'a dyn Mdns, epoch: Epoch, rand: Rand, ) -> Self { @@ -243,7 +243,7 @@ impl<'a> NocCluster<'a> { fabric_mgr, acl_mgr, failsafe, - mdns_mgr, + mdns, } } @@ -383,7 +383,7 @@ impl<'a> NocCluster<'a> { let fab_idx = self .fabric_mgr .borrow_mut() - .add(fabric, self.mdns_mgr) + .add(fabric, self.mdns) .map_err(|_| NocStatus::TableFull)?; self.add_acl(fab_idx, r.case_admin_subject)?; @@ -455,7 +455,7 @@ impl<'a> NocCluster<'a> { if self .fabric_mgr .borrow_mut() - .remove(req.fab_idx, self.mdns_mgr) + .remove(req.fab_idx, self.mdns) .is_ok() { let _ = self.acl_mgr.borrow_mut().delete_for_fabric(req.fab_idx); diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index 04369ca0..89594070 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -26,7 +26,7 @@ use crate::{ crypto::{self, hkdf_sha256, HmacSha256, KeyPair}, error::{Error, ErrorCode}, group_keys::KeySet, - mdns::{MdnsMgr, ServiceMode}, + mdns::{Mdns, ServiceMode}, tlv::{self, FromTLV, OctetStr, TLVList, TLVWriter, TagType, ToTLV, UtfStr}, utils::writebuf::WriteBuf, }; @@ -200,9 +200,9 @@ impl FabricMgr { } } - pub fn load(&mut self, data: &[u8], mdns_mgr: &MdnsMgr) -> Result<(), Error> { + pub fn load(&mut self, data: &[u8], mdns: &dyn Mdns) -> Result<(), Error> { for fabric in self.fabrics.iter().flatten() { - mdns_mgr.unpublish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; + mdns.remove(&fabric.mdns_service_name)?; } let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?; @@ -210,7 +210,7 @@ impl FabricMgr { tlv::from_tlv(&mut self.fabrics, &root)?; for fabric in self.fabrics.iter().flatten() { - mdns_mgr.publish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; + mdns.add(&fabric.mdns_service_name, ServiceMode::Commissioned)?; } self.changed = false; @@ -241,11 +241,11 @@ impl FabricMgr { self.changed } - pub fn add(&mut self, f: Fabric, mdns_mgr: &MdnsMgr) -> Result { + pub fn add(&mut self, f: Fabric, mdns: &dyn Mdns) -> Result { let slot = self.fabrics.iter().position(|x| x.is_none()); if slot.is_some() || self.fabrics.len() < MAX_SUPPORTED_FABRICS { - mdns_mgr.publish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; + mdns.add(&f.mdns_service_name, ServiceMode::Commissioned)?; self.changed = true; if let Some(index) = slot { @@ -265,10 +265,10 @@ impl FabricMgr { } } - pub fn remove(&mut self, fab_idx: u8, mdns_mgr: &MdnsMgr) -> Result<(), Error> { + pub fn remove(&mut self, fab_idx: u8, mdns: &dyn Mdns) -> Result<(), Error> { if fab_idx > 0 && fab_idx as usize <= self.fabrics.len() { if let Some(f) = self.fabrics[(fab_idx - 1) as usize].take() { - mdns_mgr.unpublish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; + mdns.remove(&f.mdns_service_name)?; self.changed = true; Ok(()) } else { diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index 897ec32e..d07ba107 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -17,40 +17,28 @@ use core::fmt::Write; -use crate::error::Error; +use crate::{data_model::cluster_basic_information::BasicInfoConfig, error::Error}; + +#[cfg(all(feature = "std", feature = "astro-dnssd"))] +pub mod astro; +pub mod builtin; +pub mod proto; pub trait Mdns { - fn add( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - service_subtypes: &[&str], - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error>; - - fn remove(&self, name: &str, service: &str, protocol: &str, port: u16) -> Result<(), Error>; + fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error>; + fn remove(&self, service: &str) -> Result<(), Error>; } impl Mdns for &mut T where T: Mdns, { - fn add( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - service_subtypes: &[&str], - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - (**self).add(name, service, protocol, port, service_subtypes, txt_kvs) + fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> { + (**self).add(service, mode) } - fn remove(&self, name: &str, service: &str, protocol: &str, port: u16) -> Result<(), Error> { - (**self).remove(name, service, protocol, port) + fn remove(&self, service: &str) -> Result<(), Error> { + (**self).remove(service) } } @@ -58,25 +46,17 @@ impl Mdns for &T where T: Mdns, { - fn add( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - service_subtypes: &[&str], - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - (**self).add(name, service, protocol, port, service_subtypes, txt_kvs) + fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> { + (**self).add(service, mode) } - fn remove(&self, name: &str, service: &str, protocol: &str, port: u16) -> Result<(), Error> { - (**self).remove(name, service, protocol, port) + fn remove(&self, service: &str) -> Result<(), Error> { + (**self).remove(service) } } #[cfg(all(feature = "std", feature = "astro-dnssd"))] -pub type DefaultMdns = astro::Mdns; +pub type DefaultMdns<'a> = astro::Mdns<'a>; #[cfg(all(feature = "std", feature = "astro-dnssd"))] pub type DefaultMdnsRunner<'a> = astro::MdnsRunner<'a>; @@ -90,29 +70,18 @@ pub type DefaultMdnsRunner<'a> = builtin::MdnsRunner<'a>; pub struct DummyMdns; impl Mdns for DummyMdns { - fn add( - &self, - _name: &str, - _service: &str, - _protocol: &str, - _port: u16, - _service_subtypes: &[&str], - _txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { + fn add(&self, _service: &str, _mode: ServiceMode) -> Result<(), Error> { Ok(()) } - fn remove( - &self, - _name: &str, - _service: &str, - _protocol: &str, - _port: u16, - ) -> Result<(), Error> { + fn remove(&self, _service: &str) -> Result<(), Error> { Ok(()) } } +pub type Service<'a> = proto::Service<'a>; + +#[derive(Debug, Clone, Eq, PartialEq)] pub enum ServiceMode { /// The commissioned state Commissioned, @@ -120,56 +89,31 @@ pub enum ServiceMode { Commissionable(u16), } -/// The mDNS service handler -pub struct MdnsMgr<'a> { - /// Vendor ID - vid: u16, - /// Product ID - pid: u16, - /// Device name - device_name: &'a str, - /// Matter port - matter_port: u16, - /// mDns service - pub(crate) mdns: &'a dyn Mdns, -} - -impl<'a> MdnsMgr<'a> { - #[inline(always)] - pub fn new( - vid: u16, - pid: u16, - device_name: &'a str, +impl ServiceMode { + pub fn service FnOnce(&Service<'a>) -> Result>( + &self, + dev_att: &BasicInfoConfig, matter_port: u16, - mdns: &'a dyn Mdns, - ) -> Self { - Self { - vid, - pid, - device_name, - matter_port, - mdns, - } - } - - /// Publish an mDNS service - /// name - is the service name (comma separated subtypes may follow) - /// mode - the current service mode - #[allow(clippy::needless_pass_by_value)] - pub fn publish_service(&self, name: &str, mode: ServiceMode) -> Result<(), Error> { - match mode { - ServiceMode::Commissioned => { - self.mdns - .add(name, "_matter", "_tcp", self.matter_port, &[], &[]) - } + name: &str, + f: F, + ) -> Result { + match self { + Self::Commissioned => f(&Service { + name, + service: "_matter", + protocol: "_tcp", + port: matter_port, + service_subtypes: &[], + txt_kvs: &[], + }), ServiceMode::Commissionable(discriminator) => { - let discriminator_str = Self::get_discriminator_str(discriminator); - let vp = self.get_vp(); + let discriminator_str = Self::get_discriminator_str(*discriminator); + let vp = Self::get_vp(dev_att.vid, dev_att.pid); - let txt_kvs = [ + let txt_kvs = &[ ("D", discriminator_str.as_str()), ("CM", "1"), - ("DN", self.device_name), + ("DN", dev_att.device_name), ("VP", &vp), ("SII", "5000"), /* Sleepy Idle Interval */ ("SAI", "300"), /* Sleepy Active Interval */ @@ -177,40 +121,29 @@ impl<'a> MdnsMgr<'a> { ("PI", ""), /* Pairing Instruction */ ]; - self.mdns.add( + f(&Service { name, - "_matterc", - "_udp", - self.matter_port, - &[ - &self.get_long_service_subtype(discriminator), - &self.get_short_service_type(discriminator), + service: "_matterc", + protocol: "_udp", + port: matter_port, + service_subtypes: &[ + &Self::get_long_service_subtype(*discriminator), + &Self::get_short_service_type(*discriminator), ], - &txt_kvs, - ) + txt_kvs, + }) } } } - pub fn unpublish_service(&self, name: &str, mode: ServiceMode) -> Result<(), Error> { - match mode { - ServiceMode::Commissioned => { - self.mdns.remove(name, "_matter", "_tcp", self.matter_port) - } - ServiceMode::Commissionable(_) => { - self.mdns.remove(name, "_matterc", "_udp", self.matter_port) - } - } - } - - fn get_long_service_subtype(&self, discriminator: u16) -> heapless::String<32> { + fn get_long_service_subtype(discriminator: u16) -> heapless::String<32> { let mut serv_type = heapless::String::new(); write!(&mut serv_type, "_L{}", discriminator).unwrap(); serv_type } - fn get_short_service_type(&self, discriminator: u16) -> heapless::String<32> { + fn get_short_service_type(discriminator: u16) -> heapless::String<32> { let short = Self::compute_short_discriminator(discriminator); let mut serv_type = heapless::String::new(); @@ -223,10 +156,10 @@ impl<'a> MdnsMgr<'a> { discriminator.into() } - fn get_vp(&self) -> heapless::String<11> { + fn get_vp(vid: u16, pid: u16) -> heapless::String<11> { let mut vp = heapless::String::new(); - write!(&mut vp, "{}+{}", self.vid, self.pid).unwrap(); + write!(&mut vp, "{}+{}", vid, pid).unwrap(); vp } @@ -239,651 +172,6 @@ impl<'a> MdnsMgr<'a> { } } -pub mod builtin { - use core::cell::RefCell; - use core::fmt::Write; - use core::mem::MaybeUninit; - use core::pin::pin; - use core::str::FromStr; - - use domain::base::header::Flags; - use domain::base::iana::Class; - use domain::base::octets::{Octets256, Octets64, OctetsBuilder}; - use domain::base::{Dname, MessageBuilder, Record, ShortBuf}; - use domain::rdata::{Aaaa, Ptr, Srv, Txt, A}; - use embassy_futures::select::{select, select3}; - use embassy_time::{Duration, Timer}; - use log::info; - - use crate::error::{Error, ErrorCode}; - use crate::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; - use crate::transport::packet::{MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE}; - use crate::transport::pipe::{Chunk, Pipe}; - use crate::transport::udp::UdpListener; - use crate::utils::select::{EitherUnwrap, Notification}; - - const IP_BROADCAST_ADDRS: [(IpAddr, u16); 2] = [ - (IpAddr::V4(Ipv4Addr::new(224, 0, 0, 251)), 5353), - ( - IpAddr::V6(Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x00fb)), - 5353, - ), - ]; - - const IP_BIND_ADDR: (IpAddr, u16) = (IpAddr::V6(Ipv6Addr::UNSPECIFIED), 5353); - - type MdnsTxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; - type MdnsRxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>; - - #[allow(clippy::too_many_arguments)] - pub fn create_record( - id: u16, - hostname: &str, - ip: [u8; 4], - ipv6: Option<[u8; 16]>, - - ttl_sec: u32, - - name: &str, - service: &str, - protocol: &str, - port: u16, - service_subtypes: &[&str], - txt_kvs: &[(&str, &str)], - - buffer: &mut [u8], - ) -> Result { - let target = domain::base::octets::Octets2048::new(); - let message = MessageBuilder::from_target(target)?; - - let mut message = message.answer(); - - let mut ptr_str = heapless::String::<40>::new(); - write!(ptr_str, "{}.{}.local", service, protocol).unwrap(); - - let mut dname = heapless::String::<60>::new(); - write!(dname, "{}.{}.{}.local", name, service, protocol).unwrap(); - - let mut hname = heapless::String::<40>::new(); - write!(hname, "{}.local", hostname).unwrap(); - - let ptr: Dname = Dname::from_str(&ptr_str).unwrap(); - let record: Record, Ptr<_>> = Record::new( - Dname::from_str("_services._dns-sd._udp.local").unwrap(), - Class::In, - ttl_sec, - Ptr::new(ptr), - ); - message.push(record)?; - - let t: Dname = Dname::from_str(&dname).unwrap(); - let record: Record, Ptr<_>> = Record::new( - Dname::from_str(&ptr_str).unwrap(), - Class::In, - ttl_sec, - Ptr::new(t), - ); - message.push(record)?; - - for sub_srv in service_subtypes { - let mut ptr_str = heapless::String::<40>::new(); - write!(ptr_str, "{}._sub.{}.{}.local", sub_srv, service, protocol).unwrap(); - - let ptr: Dname = Dname::from_str(&ptr_str).unwrap(); - let record: Record, Ptr<_>> = Record::new( - Dname::from_str("_services._dns-sd._udp.local").unwrap(), - Class::In, - ttl_sec, - Ptr::new(ptr), - ); - message.push(record)?; - - let t: Dname = Dname::from_str(&dname).unwrap(); - let record: Record, Ptr<_>> = Record::new( - Dname::from_str(&ptr_str).unwrap(), - Class::In, - ttl_sec, - Ptr::new(t), - ); - message.push(record)?; - } - - let target: Dname = Dname::from_str(&hname).unwrap(); - let record: Record, Srv<_>> = Record::new( - Dname::from_str(&dname).unwrap(), - Class::In, - ttl_sec, - Srv::new(0, 0, port, target), - ); - message.push(record)?; - - // only way I found to create multiple parts in a Txt - // each slice is the length and then the data - let mut octets = Octets256::new(); - //octets.append_slice(&[1u8, b'X']).unwrap(); - //octets.append_slice(&[2u8, b'A', b'B']).unwrap(); - //octets.append_slice(&[0u8]).unwrap(); - for (k, v) in txt_kvs { - octets - .append_slice(&[(k.len() + v.len() + 1) as u8]) - .unwrap(); - octets.append_slice(k.as_bytes()).unwrap(); - octets.append_slice(&[b'=']).unwrap(); - octets.append_slice(v.as_bytes()).unwrap(); - } - - let txt = Txt::from_octets(&mut octets).unwrap(); - - let record: Record, Txt<_>> = - Record::new(Dname::from_str(&dname).unwrap(), Class::In, ttl_sec, txt); - message.push(record)?; - - let record: Record, A> = Record::new( - Dname::from_str(&hname).unwrap(), - Class::In, - ttl_sec, - A::from_octets(ip[0], ip[1], ip[2], ip[3]), - ); - message.push(record)?; - - if let Some(ipv6) = ipv6 { - let record: Record, Aaaa> = Record::new( - Dname::from_str(&hname).unwrap(), - Class::In, - ttl_sec, - Aaaa::new(ipv6.into()), - ); - message.push(record)?; - } - - let headerb = message.header_mut(); - headerb.set_id(id); - headerb.set_opcode(domain::base::iana::Opcode::Query); - headerb.set_rcode(domain::base::iana::Rcode::NoError); - - let mut flags = Flags::new(); - flags.qr = true; - flags.aa = true; - headerb.set_flags(flags); - - let target = message.finish(); - - buffer[..target.len()].copy_from_slice(target.as_ref()); - - Ok(target.len()) - } - - #[derive(Debug, Clone)] - struct MdnsEntry { - key: heapless::String<64>, - record: heapless::Vec, - } - - impl MdnsEntry { - #[inline(always)] - const fn new(key: heapless::String<64>) -> Self { - Self { - key, - record: heapless::Vec::new(), - } - } - } - - pub struct Mdns<'a> { - id: u16, - hostname: &'a str, - ip: [u8; 4], - ipv6: Option<[u8; 16]>, - entries: RefCell>, - notification: Notification, - } - - impl<'a> Mdns<'a> { - #[inline(always)] - pub const fn new(id: u16, hostname: &'a str, ip: [u8; 4], ipv6: Option<[u8; 16]>) -> Self { - Self { - id, - hostname, - ip, - ipv6, - entries: RefCell::new(heapless::Vec::new()), - notification: Notification::new(), - } - } - - pub fn add( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - service_subtypes: &[&str], - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - info!( - "Registering mDNS service {}/{}.{} [{:?}]/{}, keys [{:?}]", - name, service, protocol, service_subtypes, port, txt_kvs - ); - - let key = self.key(name, service, protocol, port); - - let mut entries = self.entries.borrow_mut(); - - entries.retain(|entry| entry.key != key); - entries - .push(MdnsEntry::new(key)) - .map_err(|_| ErrorCode::NoSpace)?; - - let entry = entries.iter_mut().last().unwrap(); - entry - .record - .resize(1024, 0) - .map_err(|_| ErrorCode::NoSpace) - .unwrap(); - - match create_record( - self.id, - self.hostname, - self.ip, - self.ipv6, - 60, /*ttl_sec*/ - name, - service, - protocol, - port, - service_subtypes, - txt_kvs, - &mut entry.record, - ) { - Ok(len) => entry.record.truncate(len), - Err(_) => { - entries.pop(); - Err(ErrorCode::NoSpace)?; - } - } - - self.notification.signal(()); - - Ok(()) - } - - pub fn remove( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - ) -> Result<(), Error> { - info!( - "Deregistering mDNS service {}/{}.{}/{}", - name, service, protocol, port - ); - - let key = self.key(name, service, protocol, port); - - let mut entries = self.entries.borrow_mut(); - - let old_len = entries.len(); - - entries.retain(|entry| entry.key != key); - - if entries.len() != old_len { - self.notification.signal(()); - } - - Ok(()) - } - - fn key( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - ) -> heapless::String<64> { - let mut key = heapless::String::new(); - - write!(&mut key, "{name}.{service}.{protocol}.{port}").unwrap(); - - key - } - } - - pub struct MdnsRunner<'a>(&'a Mdns<'a>); - - impl<'a> MdnsRunner<'a> { - pub const fn new(mdns: &'a Mdns<'a>) -> Self { - Self(mdns) - } - - pub async fn run_udp(&mut self) -> Result<(), Error> { - let mut tx_buf = MdnsTxBuf::uninit(); - let mut rx_buf = MdnsRxBuf::uninit(); - - let tx_buf = &mut tx_buf; - let rx_buf = &mut rx_buf; - - let tx_pipe = Pipe::new(unsafe { tx_buf.assume_init_mut() }); - let rx_pipe = Pipe::new(unsafe { rx_buf.assume_init_mut() }); - - let tx_pipe = &tx_pipe; - let rx_pipe = &rx_pipe; - - let udp = UdpListener::new(SocketAddr::new(IP_BIND_ADDR.0, IP_BIND_ADDR.1)).await?; - let udp = &udp; - - let mut tx = pin!(async move { - loop { - { - let mut data = tx_pipe.data.lock().await; - - if let Some(chunk) = data.chunk { - udp.send(chunk.addr.unwrap_udp(), &data.buf[chunk.start..chunk.end]) - .await?; - data.chunk = None; - tx_pipe.data_consumed_notification.signal(()); - } - } - - tx_pipe.data_supplied_notification.wait().await; - } - }); - - let mut rx = pin!(async move { - loop { - { - let mut data = rx_pipe.data.lock().await; - - if data.chunk.is_none() { - let (len, addr) = udp.recv(data.buf).await?; - - data.chunk = Some(Chunk { - start: 0, - end: len, - addr: Address::Udp(addr), - }); - rx_pipe.data_supplied_notification.signal(()); - } - } - - rx_pipe.data_consumed_notification.wait().await; - } - }); - - let mut run = pin!(async move { self.run(tx_pipe, rx_pipe).await }); - - select3(&mut tx, &mut rx, &mut run).await.unwrap() - } - - pub async fn run(&self, tx_pipe: &Pipe<'_>, rx_pipe: &Pipe<'_>) -> Result<(), Error> { - let mut broadcast = pin!(self.broadcast(tx_pipe)); - let mut respond = pin!(self.respond(rx_pipe, tx_pipe)); - - select(&mut broadcast, &mut respond).await.unwrap() - } - - #[allow(clippy::await_holding_refcell_ref)] - async fn broadcast(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> { - loop { - select( - self.0.notification.wait(), - Timer::after(Duration::from_secs(30)), - ) - .await; - - let mut index = 0; - - 'outer: loop { - for (addr, port) in IP_BROADCAST_ADDRS { - loop { - { - let mut data = tx_pipe.data.lock().await; - - if data.chunk.is_none() { - let entries = self.0.entries.borrow(); - let entry = entries.get(index); - - if let Some(entry) = entry { - info!( - "Broadasting mDNS entry {} on {}:{}", - &entry.key, addr, port - ); - - let len = entry.record.len(); - data.buf[..len].copy_from_slice(&entry.record); - drop(entries); - - data.chunk = Some(Chunk { - start: 0, - end: len, - addr: Address::Udp(SocketAddr::new(addr, port)), - }); - - tx_pipe.data_supplied_notification.signal(()); - } else { - break 'outer; - } - - break; - } - } - - tx_pipe.data_consumed_notification.wait().await; - } - } - - index += 1; - } - } - } - - #[allow(clippy::await_holding_refcell_ref)] - async fn respond(&self, rx_pipe: &Pipe<'_>, _tx_pipe: &Pipe<'_>) -> Result<(), Error> { - loop { - { - let mut data = rx_pipe.data.lock().await; - - if let Some(_chunk) = data.chunk { - // TODO: Process the incoming packed and only answer what we are being queried about - - data.chunk = None; - rx_pipe.data_consumed_notification.signal(()); - - self.0.notification.signal(()); - } - } - - rx_pipe.data_supplied_notification.wait().await; - } - } - } - - impl<'a> super::Mdns for Mdns<'a> { - fn add( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - service_subtypes: &[&str], - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - Mdns::add( - self, - name, - service, - protocol, - port, - service_subtypes, - txt_kvs, - ) - } - - fn remove( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - ) -> Result<(), Error> { - Mdns::remove(self, name, service, protocol, port) - } - } -} - -#[cfg(all(feature = "std", feature = "astro-dnssd"))] -pub mod astro { - use core::cell::RefCell; - use std::collections::HashMap; - - use crate::{ - error::{Error, ErrorCode}, - transport::pipe::Pipe, - }; - use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService}; - use log::info; - - #[derive(Debug, Clone, Eq, PartialEq, Hash)] - struct ServiceId { - name: String, - service: String, - protocol: String, - port: u16, - } - - pub struct Mdns { - services: RefCell>, - } - - impl Mdns { - pub fn new(_id: u16, _hostname: &str, _ip: [u8; 4], _ipv6: Option<[u8; 16]>) -> Self { - Self::native_new() - } - - pub fn native_new() -> Self { - Self { - services: RefCell::new(HashMap::new()), - } - } - - pub fn add( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - service_subtypes: &[&str], - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - info!( - "Registering mDNS service {}/{}.{} [{:?}]/{}", - name, service, protocol, service_subtypes, port - ); - - let _ = self.remove(name, service, protocol, port); - - let composite_service_type = if !service_subtypes.is_empty() { - format!("{}.{},{}", service, protocol, service_subtypes.join(",")) - } else { - format!("{}.{}", service, protocol) - }; - - let mut builder = DNSServiceBuilder::new(&composite_service_type, port).with_name(name); - - for kvs in txt_kvs { - info!("mDNS TXT key {} val {}", kvs.0, kvs.1); - builder = builder.with_key_value(kvs.0.to_string(), kvs.1.to_string()); - } - - let svc = builder.register().map_err(|_| ErrorCode::MdnsError)?; - - self.services.borrow_mut().insert( - ServiceId { - name: name.into(), - service: service.into(), - protocol: protocol.into(), - port, - }, - svc, - ); - - Ok(()) - } - - pub fn remove( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - ) -> Result<(), Error> { - let id = ServiceId { - name: name.into(), - service: service.into(), - protocol: protocol.into(), - port, - }; - - if self.services.borrow_mut().remove(&id).is_some() { - info!( - "Deregistering mDNS service {}/{}.{}/{}", - name, service, protocol, port - ); - } - - Ok(()) - } - } - - pub struct MdnsRunner<'a>(&'a Mdns); - - impl<'a> MdnsRunner<'a> { - pub const fn new(mdns: &'a Mdns) -> Self { - Self(mdns) - } - - pub async fn run_udp(&mut self) -> Result<(), Error> { - core::future::pending::>().await - } - - pub async fn run(&self, _tx_pipe: &Pipe<'_>, _rx_pipe: &Pipe<'_>) -> Result<(), Error> { - core::future::pending::>().await - } - } - - impl super::Mdns for Mdns { - fn add( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - service_subtypes: &[&str], - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - Mdns::add( - self, - name, - service, - protocol, - port, - service_subtypes, - txt_kvs, - ) - } - - fn remove( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - ) -> Result<(), Error> { - Mdns::remove(self, name, service, protocol, port) - } - } -} - #[cfg(test)] mod tests { use super::*; @@ -891,11 +179,11 @@ mod tests { #[test] fn can_compute_short_discriminator() { let discriminator: u16 = 0b0000_1111_0000_0000; - let short = MdnsMgr::compute_short_discriminator(discriminator); + let short = ServiceMode::compute_short_discriminator(discriminator); assert_eq!(short, 0b1111); let discriminator: u16 = 840; - let short = MdnsMgr::compute_short_discriminator(discriminator); + let short = ServiceMode::compute_short_discriminator(discriminator); assert_eq!(short, 3); } } diff --git a/matter/src/mdns/astro.rs b/matter/src/mdns/astro.rs new file mode 100644 index 00000000..12426cb6 --- /dev/null +++ b/matter/src/mdns/astro.rs @@ -0,0 +1,106 @@ +use core::cell::RefCell; +use std::collections::HashMap; + +use crate::{ + data_model::cluster_basic_information::BasicInfoConfig, + error::{Error, ErrorCode}, + transport::pipe::Pipe, +}; +use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService}; +use log::info; + +use super::ServiceMode; + +pub struct Mdns<'a> { + dev_det: &'a BasicInfoConfig<'a>, + matter_port: u16, + services: RefCell>, +} + +impl<'a> Mdns<'a> { + pub fn new( + _id: u16, + _hostname: &str, + _ip: [u8; 4], + _ipv6: Option<[u8; 16]>, + dev_det: &'a BasicInfoConfig<'a>, + matter_port: u16, + ) -> Self { + Self::native_new(dev_det, matter_port) + } + + pub fn native_new(dev_det: &'a BasicInfoConfig<'a>, matter_port: u16) -> Self { + Self { + dev_det, + matter_port, + services: RefCell::new(HashMap::new()), + } + } + + pub fn add(&self, name: &str, mode: ServiceMode) -> Result<(), Error> { + info!("Registering mDNS service {}/{:?}", name, mode); + + let _ = self.remove(name); + + mode.service(self.dev_det, self.matter_port, name, |service| { + let composite_service_type = if !service.service_subtypes.is_empty() { + format!( + "{}.{},{}", + service.service, + service.protocol, + service.service_subtypes.join(",") + ) + } else { + format!("{}.{}", service.service, service.protocol) + }; + + let mut builder = DNSServiceBuilder::new(&composite_service_type, service.port) + .with_name(service.name); + + for kvs in service.txt_kvs { + info!("mDNS TXT key {} val {}", kvs.0, kvs.1); + builder = builder.with_key_value(kvs.0.to_string(), kvs.1.to_string()); + } + + let svc = builder.register().map_err(|_| ErrorCode::MdnsError)?; + + self.services.borrow_mut().insert(service.name.into(), svc); + + Ok(()) + }) + } + + pub fn remove(&self, name: &str) -> Result<(), Error> { + if self.services.borrow_mut().remove(name).is_some() { + info!("Deregistering mDNS service {}", name); + } + + Ok(()) + } +} + +pub struct MdnsRunner<'a>(&'a Mdns<'a>); + +impl<'a> MdnsRunner<'a> { + pub const fn new(mdns: &'a Mdns<'a>) -> Self { + Self(mdns) + } + + pub async fn run_udp(&mut self) -> Result<(), Error> { + core::future::pending::>().await + } + + pub async fn run(&self, _tx_pipe: &Pipe<'_>, _rx_pipe: &Pipe<'_>) -> Result<(), Error> { + core::future::pending::>().await + } +} + +impl<'a> super::Mdns for Mdns<'a> { + fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> { + Mdns::add(self, service, mode) + } + + fn remove(&self, service: &str) -> Result<(), Error> { + Mdns::remove(self, service) + } +} diff --git a/matter/src/mdns/builtin.rs b/matter/src/mdns/builtin.rs new file mode 100644 index 00000000..95c6ad69 --- /dev/null +++ b/matter/src/mdns/builtin.rs @@ -0,0 +1,317 @@ +use core::{cell::RefCell, mem::MaybeUninit, pin::pin}; + +use domain::base::name::FromStrError; +use domain::base::{octets::ParseError, ShortBuf}; +use embassy_futures::select::{select, select3}; +use embassy_time::{Duration, Timer}; +use log::info; + +use crate::data_model::cluster_basic_information::BasicInfoConfig; +use crate::error::{Error, ErrorCode}; +use crate::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use crate::transport::packet::{MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE}; +use crate::transport::pipe::{Chunk, Pipe}; +use crate::transport::udp::UdpListener; +use crate::utils::select::{EitherUnwrap, Notification}; + +use super::{ + proto::{Host, Services}, + Service, ServiceMode, +}; + +const IP_BROADCAST_ADDRS: [(IpAddr, u16); 2] = [ + (IpAddr::V4(Ipv4Addr::new(224, 0, 0, 251)), 5353), + ( + IpAddr::V6(Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x00fb)), + 5353, + ), +]; + +const IP_BIND_ADDR: (IpAddr, u16) = (IpAddr::V6(Ipv6Addr::UNSPECIFIED), 5353); + +type MdnsTxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; +type MdnsRxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>; + +pub struct Mdns<'a> { + host: Host<'a>, + dev_det: &'a BasicInfoConfig<'a>, + matter_port: u16, + services: RefCell, ServiceMode), 4>>, + notification: Notification, +} + +impl<'a> Mdns<'a> { + #[inline(always)] + pub const fn new( + id: u16, + hostname: &'a str, + ip: [u8; 4], + ipv6: Option<[u8; 16]>, + dev_det: &'a BasicInfoConfig<'a>, + matter_port: u16, + ) -> Self { + Self { + host: Host { + id, + hostname, + ip, + ipv6, + }, + dev_det, + matter_port, + services: RefCell::new(heapless::Vec::new()), + notification: Notification::new(), + } + } + + pub fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> { + let mut services = self.services.borrow_mut(); + + services.retain(|(name, _)| name != service); + services + .push((service.into(), mode)) + .map_err(|_| ErrorCode::NoSpace)?; + + self.notification.signal(()); + + Ok(()) + } + + pub fn remove(&self, service: &str) -> Result<(), Error> { + let mut services = self.services.borrow_mut(); + + services.retain(|(name, _)| name != service); + + Ok(()) + } + + pub fn for_each(&self, mut callback: F) -> Result<(), Error> + where + F: FnMut(&Service) -> Result<(), Error>, + { + let services = self.services.borrow(); + + for (service, mode) in &*services { + mode.service(self.dev_det, self.matter_port, service, |service| { + callback(service) + })?; + } + + Ok(()) + } +} + +pub struct MdnsRunner<'a>(&'a Mdns<'a>); + +impl<'a> MdnsRunner<'a> { + pub const fn new(mdns: &'a Mdns<'a>) -> Self { + Self(mdns) + } + + pub async fn run_udp(&mut self) -> Result<(), Error> { + let mut tx_buf = MdnsTxBuf::uninit(); + let mut rx_buf = MdnsRxBuf::uninit(); + + let tx_buf = &mut tx_buf; + let rx_buf = &mut rx_buf; + + let tx_pipe = Pipe::new(unsafe { tx_buf.assume_init_mut() }); + let rx_pipe = Pipe::new(unsafe { rx_buf.assume_init_mut() }); + + let tx_pipe = &tx_pipe; + let rx_pipe = &rx_pipe; + + let mut udp = UdpListener::new(SocketAddr::new(IP_BIND_ADDR.0, IP_BIND_ADDR.1)).await?; + + for (ip, _) in IP_BROADCAST_ADDRS { + udp.join_multicast(ip).await?; + } + + let udp = &udp; + + let mut tx = pin!(async move { + loop { + { + let mut data = tx_pipe.data.lock().await; + + if let Some(chunk) = data.chunk { + udp.send(chunk.addr.unwrap_udp(), &data.buf[chunk.start..chunk.end]) + .await?; + data.chunk = None; + tx_pipe.data_consumed_notification.signal(()); + } + } + + tx_pipe.data_supplied_notification.wait().await; + } + }); + + let mut rx = pin!(async move { + loop { + { + let mut data = rx_pipe.data.lock().await; + + if data.chunk.is_none() { + let (len, addr) = udp.recv(data.buf).await?; + + data.chunk = Some(Chunk { + start: 0, + end: len, + addr: Address::Udp(addr), + }); + rx_pipe.data_supplied_notification.signal(()); + } + } + + rx_pipe.data_consumed_notification.wait().await; + } + }); + + let mut run = pin!(async move { self.run(tx_pipe, rx_pipe).await }); + + select3(&mut tx, &mut rx, &mut run).await.unwrap() + } + + pub async fn run(&self, tx_pipe: &Pipe<'_>, rx_pipe: &Pipe<'_>) -> Result<(), Error> { + let mut broadcast = pin!(self.broadcast(tx_pipe)); + let mut respond = pin!(self.respond(rx_pipe, tx_pipe)); + + select(&mut broadcast, &mut respond).await.unwrap() + } + + #[allow(clippy::await_holding_refcell_ref)] + async fn broadcast(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> { + loop { + select( + self.0.notification.wait(), + Timer::after(Duration::from_secs(30)), + ) + .await; + + for (addr, port) in IP_BROADCAST_ADDRS { + loop { + let sent = { + let mut data = tx_pipe.data.lock().await; + + if data.chunk.is_none() { + let len = self.0.host.broadcast(&self.0, data.buf, 60)?; + + if len > 0 { + info!("Broadasting mDNS entry to {}:{}", addr, port); + + data.chunk = Some(Chunk { + start: 0, + end: len, + addr: Address::Udp(SocketAddr::new(addr, port)), + }); + + tx_pipe.data_supplied_notification.signal(()); + } + + true + } else { + false + } + }; + + if sent { + break; + } else { + tx_pipe.data_consumed_notification.wait().await; + } + } + } + } + } + + #[allow(clippy::await_holding_refcell_ref)] + async fn respond(&self, rx_pipe: &Pipe<'_>, tx_pipe: &Pipe<'_>) -> Result<(), Error> { + loop { + { + let mut rx_data = rx_pipe.data.lock().await; + + if let Some(rx_chunk) = rx_data.chunk { + let data = &rx_data.buf[rx_chunk.start..rx_chunk.end]; + + loop { + let sent = { + let mut tx_data = tx_pipe.data.lock().await; + + if tx_data.chunk.is_none() { + let len = self.0.host.respond(&self.0, data, tx_data.buf, 60)?; + + if len > 0 { + info!("Replying to mDNS query from {}", rx_chunk.addr); + + tx_data.chunk = Some(Chunk { + start: 0, + end: len, + addr: rx_chunk.addr, + }); + + tx_pipe.data_supplied_notification.signal(()); + } + + true + } else { + false + } + }; + + if sent { + break; + } else { + tx_pipe.data_consumed_notification.wait().await; + } + } + + // info!("Got mDNS query"); + + rx_data.chunk = None; + rx_pipe.data_consumed_notification.signal(()); + } + } + + rx_pipe.data_supplied_notification.wait().await; + } + } +} + +impl<'a> super::Mdns for Mdns<'a> { + fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> { + Mdns::add(self, service, mode) + } + + fn remove(&self, service: &str) -> Result<(), Error> { + Mdns::remove(self, service) + } +} + +impl<'a> Services for Mdns<'a> { + type Error = crate::error::Error; + + fn for_each(&self, callback: F) -> Result<(), Error> + where + F: FnMut(&Service) -> Result<(), Error>, + { + Mdns::for_each(self, callback) + } +} + +impl From for Error { + fn from(_e: ShortBuf) -> Self { + Self::new(ErrorCode::NoSpace) + } +} + +impl From for Error { + fn from(_e: ParseError) -> Self { + Self::new(ErrorCode::MdnsError) + } +} + +impl From for Error { + fn from(_e: FromStrError) -> Self { + Self::new(ErrorCode::MdnsError) + } +} diff --git a/matter/src/mdns/proto.rs b/matter/src/mdns/proto.rs new file mode 100644 index 00000000..6fac2c71 --- /dev/null +++ b/matter/src/mdns/proto.rs @@ -0,0 +1,508 @@ +use core::fmt::Write; +use core::str::FromStr; + +use domain::{ + base::{ + header::Flags, + iana::Class, + message_builder::AnswerBuilder, + name::FromStrError, + octets::{Octets256, Octets64, OctetsBuilder, ParseError}, + Dname, Message, MessageBuilder, Record, Rtype, ShortBuf, ToDname, + }, + rdata::{Aaaa, Ptr, Srv, Txt, A}, +}; +use log::trace; + +pub trait Services { + type Error: From + From + From; + + fn for_each(&self, callback: F) -> Result<(), Self::Error> + where + F: FnMut(&Service) -> Result<(), Self::Error>; +} + +impl Services for &mut T +where + T: Services, +{ + type Error = T::Error; + + fn for_each(&self, callback: F) -> Result<(), Self::Error> + where + F: FnMut(&Service) -> Result<(), Self::Error>, + { + (**self).for_each(callback) + } +} + +impl Services for &T +where + T: Services, +{ + type Error = T::Error; + + fn for_each(&self, callback: F) -> Result<(), Self::Error> + where + F: FnMut(&Service) -> Result<(), Self::Error>, + { + (**self).for_each(callback) + } +} + +pub struct Host<'a> { + pub id: u16, + pub hostname: &'a str, + pub ip: [u8; 4], + pub ipv6: Option<[u8; 16]>, +} + +impl<'a> Host<'a> { + pub fn broadcast( + &self, + services: T, + buf: &mut [u8], + ttl_sec: u32, + ) -> Result { + let buf = Buf(buf, 0); + + let message = MessageBuilder::from_target(buf)?; + + let mut answer = message.answer(); + + self.set_broadcast(services, &mut answer, ttl_sec)?; + + let buf = answer.finish(); + + Ok(buf.1) + } + + pub fn respond( + &self, + services: T, + data: &[u8], + buf: &mut [u8], + ttl_sec: u32, + ) -> Result { + let buf = Buf(buf, 0); + + let message = MessageBuilder::from_target(buf)?; + + let mut answer = message.answer(); + + if self.set_response(data, services, &mut answer, ttl_sec)? { + let buf = answer.finish(); + + Ok(buf.1) + } else { + Ok(0) + } + } + + fn set_broadcast( + &self, + services: F, + answer: &mut AnswerBuilder, + ttl_sec: u32, + ) -> Result<(), F::Error> + where + T: OctetsBuilder + AsMut<[u8]>, + F: Services, + { + self.set_header(answer); + + self.add_ipv4(answer, ttl_sec)?; + self.add_ipv6(answer, ttl_sec)?; + + services.for_each(|service| { + service.add_service(answer, self.hostname, ttl_sec)?; + service.add_service_type(answer, ttl_sec)?; + service.add_service_subtypes(answer, ttl_sec)?; + service.add_txt(answer, ttl_sec)?; + + Ok(()) + })?; + + Ok(()) + } + + fn set_response( + &self, + data: &[u8], + services: F, + answer: &mut AnswerBuilder, + ttl_sec: u32, + ) -> Result + where + T: OctetsBuilder + AsMut<[u8]>, + F: Services, + { + self.set_header(answer); + + let message = Message::from_octets(data)?; + + let mut replied = false; + + for question in message.question() { + trace!("Handling question {:?}", question); + + let question = question?; + + match question.qtype() { + Rtype::A + if question + .qname() + .name_eq(&Host::host_fqdn(self.hostname, true)?) => + { + self.add_ipv4(answer, ttl_sec)?; + replied = true; + } + Rtype::Aaaa + if question + .qname() + .name_eq(&Host::host_fqdn(self.hostname, true)?) => + { + self.add_ipv6(answer, ttl_sec)?; + replied = true; + } + Rtype::Srv => { + services.for_each(|service| { + if question.qname().name_eq(&service.service_fqdn(true)?) { + self.add_ipv4(answer, ttl_sec)?; + self.add_ipv6(answer, ttl_sec)?; + service.add_service(answer, self.hostname, ttl_sec)?; + replied = true; + } + + Ok(()) + })?; + } + Rtype::Ptr => { + services.for_each(|service| { + if question.qname().name_eq(&Service::dns_sd_fqdn(true)?) { + service.add_service_type(answer, ttl_sec)?; + replied = true; + } else if question.qname().name_eq(&service.service_type_fqdn(true)?) { + // TODO + self.add_ipv4(answer, ttl_sec)?; + self.add_ipv6(answer, ttl_sec)?; + service.add_service(answer, self.hostname, ttl_sec)?; + service.add_service_type(answer, ttl_sec)?; + service.add_service_subtypes(answer, ttl_sec)?; + service.add_txt(answer, ttl_sec)?; + replied = true; + } + + Ok(()) + })?; + } + Rtype::Txt => { + services.for_each(|service| { + if question.qname().name_eq(&service.service_fqdn(true)?) { + service.add_txt(answer, ttl_sec)?; + replied = true; + } + + Ok(()) + })?; + } + Rtype::Any => { + // A / AAAA + if question + .qname() + .name_eq(&Host::host_fqdn(self.hostname, true)?) + { + self.add_ipv4(answer, ttl_sec)?; + self.add_ipv6(answer, ttl_sec)?; + replied = true; + } + + // PTR + services.for_each(|service| { + if question.qname().name_eq(&Service::dns_sd_fqdn(true)?) { + service.add_service_type(answer, ttl_sec)?; + replied = true; + } else if question.qname().name_eq(&service.service_type_fqdn(true)?) { + // TODO + self.add_ipv4(answer, ttl_sec)?; + self.add_ipv6(answer, ttl_sec)?; + service.add_service(answer, self.hostname, ttl_sec)?; + service.add_service_type(answer, ttl_sec)?; + service.add_service_subtypes(answer, ttl_sec)?; + service.add_txt(answer, ttl_sec)?; + replied = true; + } + + Ok(()) + })?; + + // SRV + services.for_each(|service| { + if question.qname().name_eq(&service.service_fqdn(true)?) { + self.add_ipv4(answer, ttl_sec)?; + self.add_ipv6(answer, ttl_sec)?; + service.add_service(answer, self.hostname, ttl_sec)?; + replied = true; + } + + Ok(()) + })?; + } + _ => (), + } + } + + Ok(replied) + } + + fn set_header>(&self, answer: &mut AnswerBuilder) { + let header = answer.header_mut(); + header.set_id(self.id); + header.set_opcode(domain::base::iana::Opcode::Query); + header.set_rcode(domain::base::iana::Rcode::NoError); + + let mut flags = Flags::new(); + flags.qr = true; + flags.aa = true; + header.set_flags(flags); + } + + fn add_ipv4>( + &self, + answer: &mut AnswerBuilder, + ttl_sec: u32, + ) -> Result<(), ShortBuf> { + answer.push(Record::, A>::new( + Self::host_fqdn(self.hostname, false).unwrap(), + Class::In, + ttl_sec, + A::from_octets(self.ip[0], self.ip[1], self.ip[2], self.ip[3]), + )) + } + + fn add_ipv6>( + &self, + answer: &mut AnswerBuilder, + ttl_sec: u32, + ) -> Result<(), ShortBuf> { + if let Some(ip) = &self.ipv6 { + answer.push(Record::, Aaaa>::new( + Self::host_fqdn(self.hostname, false).unwrap(), + Class::In, + ttl_sec, + Aaaa::new((*ip).into()), + )) + } else { + Ok(()) + } + } + + fn host_fqdn(hostname: &str, suffix: bool) -> Result, FromStrError> { + let suffix = if suffix { "." } else { "" }; + + let mut host_fqdn = heapless::String::<60>::new(); + write!(host_fqdn, "{}.local{}", hostname, suffix,).unwrap(); + + Dname::from_str(&host_fqdn) + } +} + +pub struct Service<'a> { + pub name: &'a str, + pub service: &'a str, + pub protocol: &'a str, + pub port: u16, + pub service_subtypes: &'a [&'a str], + pub txt_kvs: &'a [(&'a str, &'a str)], +} + +impl<'a> Service<'a> { + fn add_service>( + &self, + answer: &mut AnswerBuilder, + hostname: &str, + ttl_sec: u32, + ) -> Result<(), ShortBuf> { + answer.push(Record::, Srv<_>>::new( + self.service_fqdn(false).unwrap(), + Class::In, + ttl_sec, + Srv::new(0, 0, self.port, Host::host_fqdn(hostname, false).unwrap()), + )) + } + + fn add_service_type>( + &self, + answer: &mut AnswerBuilder, + ttl_sec: u32, + ) -> Result<(), ShortBuf> { + answer.push(Record::, Ptr<_>>::new( + Self::dns_sd_fqdn(false).unwrap(), + Class::In, + ttl_sec, + Ptr::new(self.service_type_fqdn(false).unwrap()), + ))?; + + answer.push(Record::, Ptr<_>>::new( + self.service_type_fqdn(false).unwrap(), + Class::In, + ttl_sec, + Ptr::new(self.service_fqdn(false).unwrap()), + )) + } + + fn add_service_subtypes>( + &self, + answer: &mut AnswerBuilder, + ttl_sec: u32, + ) -> Result<(), ShortBuf> { + for service_subtype in self.service_subtypes { + self.add_service_subtype(answer, service_subtype, ttl_sec)?; + } + + Ok(()) + } + + fn add_service_subtype>( + &self, + answer: &mut AnswerBuilder, + service_subtype: &str, + ttl_sec: u32, + ) -> Result<(), ShortBuf> { + answer.push(Record::, Ptr<_>>::new( + Self::dns_sd_fqdn(false).unwrap(), + Class::In, + ttl_sec, + Ptr::new(self.service_subtype_fqdn(service_subtype, false).unwrap()), + ))?; + + answer.push(Record::, Ptr<_>>::new( + self.service_subtype_fqdn(service_subtype, false).unwrap(), + Class::In, + ttl_sec, + Ptr::new(self.service_fqdn(false).unwrap()), + )) + } + + fn add_txt>( + &self, + answer: &mut AnswerBuilder, + ttl_sec: u32, + ) -> Result<(), ShortBuf> { + // only way I found to create multiple parts in a Txt + // each slice is the length and then the data + let mut octets = Octets256::new(); + //octets.append_slice(&[1u8, b'X'])?; + //octets.append_slice(&[2u8, b'A', b'B'])?; + //octets.append_slice(&[0u8])?; + for (k, v) in self.txt_kvs { + octets.append_slice(&[(k.len() + v.len() + 1) as u8])?; + octets.append_slice(k.as_bytes())?; + octets.append_slice(&[b'='])?; + octets.append_slice(v.as_bytes())?; + } + + let txt = Txt::from_octets(&mut octets).unwrap(); + + answer.push(Record::, Txt<_>>::new( + self.service_fqdn(false).unwrap(), + Class::In, + ttl_sec, + txt, + )) + } + + fn service_fqdn(&self, suffix: bool) -> Result, FromStrError> { + let suffix = if suffix { "." } else { "" }; + + let mut service_fqdn = heapless::String::<60>::new(); + write!( + service_fqdn, + "{}.{}.{}.local{}", + self.name, self.service, self.protocol, suffix, + ) + .unwrap(); + + Dname::from_str(&service_fqdn) + } + + fn service_type_fqdn(&self, suffix: bool) -> Result, FromStrError> { + let suffix = if suffix { "." } else { "" }; + + let mut service_type_fqdn = heapless::String::<60>::new(); + write!( + service_type_fqdn, + "{}.{}.local{}", + self.service, self.protocol, suffix, + ) + .unwrap(); + + Dname::from_str(&service_type_fqdn) + } + + fn service_subtype_fqdn( + &self, + service_subtype: &str, + suffix: bool, + ) -> Result, FromStrError> { + let suffix = if suffix { "." } else { "" }; + + let mut service_subtype_fqdn = heapless::String::<40>::new(); + write!( + service_subtype_fqdn, + "{}._sub.{}.{}.local{}", + service_subtype, self.service, self.protocol, suffix, + ) + .unwrap(); + + Dname::from_str(&service_subtype_fqdn) + } + + fn dns_sd_fqdn(suffix: bool) -> Result, FromStrError> { + if suffix { + Dname::from_str("_services._dns-sd._udp.local.") + } else { + Dname::from_str("_services._dns-sd._udp.local") + } + } +} + +struct Buf<'a>(pub &'a mut [u8], pub usize); + +impl<'a> OctetsBuilder for Buf<'a> { + type Octets = Self; + + fn append_slice(&mut self, slice: &[u8]) -> Result<(), ShortBuf> { + if self.1 + slice.len() <= self.0.len() { + let end = self.1 + slice.len(); + self.0[self.1..end].copy_from_slice(slice); + self.1 = end; + + Ok(()) + } else { + Err(ShortBuf) + } + } + + fn truncate(&mut self, len: usize) { + self.1 = len; + } + + fn freeze(self) -> Self::Octets { + self + } + + fn len(&self) -> usize { + self.1 + } + + fn is_empty(&self) -> bool { + self.1 == 0 + } +} + +impl<'a> AsMut<[u8]> for Buf<'a> { + fn as_mut(&mut self) -> &mut [u8] { + &mut self.0[..self.1] + } +} diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index 523278e6..0ad17ed3 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -20,7 +20,7 @@ use core::{borrow::Borrow, cell::RefCell}; use crate::{ error::*, fabric::FabricMgr, - mdns::MdnsMgr, + mdns::Mdns, secure_channel::common::*, tlv, transport::{proto_ctx::ProtoCtx, session::CloneData}, @@ -36,7 +36,7 @@ use super::{case::Case, pake::PaseMgr}; pub struct SecureChannel<'a> { case: Case<'a>, pase: &'a RefCell, - mdns: &'a MdnsMgr<'a>, + mdns: &'a dyn Mdns, } impl<'a> SecureChannel<'a> { @@ -44,7 +44,7 @@ impl<'a> SecureChannel<'a> { pub fn new< T: Borrow> + Borrow> - + Borrow> + + Borrow + Borrow + Borrow, >( @@ -62,7 +62,7 @@ impl<'a> SecureChannel<'a> { pub fn wrap( pase: &'a RefCell, fabric: &'a RefCell, - mdns: &'a MdnsMgr<'a>, + mdns: &'a dyn Mdns, rand: Rand, ) -> Self { Self { diff --git a/matter/src/secure_channel/pake.rs b/matter/src/secure_channel/pake.rs index 60920d03..79f7d2cd 100644 --- a/matter/src/secure_channel/pake.rs +++ b/matter/src/secure_channel/pake.rs @@ -24,7 +24,7 @@ use super::{ use crate::{ crypto, error::{Error, ErrorCode}, - mdns::{MdnsMgr, ServiceMode}, + mdns::{Mdns, ServiceMode}, secure_channel::common::OpCode, tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV}, transport::{ @@ -39,7 +39,7 @@ use log::{error, info}; #[allow(clippy::large_enum_variant)] enum PaseMgrState { - Enabled(Pake, heapless::String<16>, u16), + Enabled(Pake, heapless::String<16>), Disabled, } @@ -60,14 +60,14 @@ impl PaseMgr { } pub fn is_pase_session_enabled(&self) -> bool { - matches!(&self.state, PaseMgrState::Enabled(_, _, _)) + matches!(&self.state, PaseMgrState::Enabled(_, _)) } pub fn enable_pase_session( &mut self, verifier: VerifierData, discriminator: u16, - mdns: &MdnsMgr, + mdns: &dyn Mdns, ) -> Result<(), Error> { let mut buf = [0; 8]; (self.rand)(&mut buf); @@ -76,25 +76,21 @@ impl PaseMgr { let mut mdns_service_name = heapless::String::<16>::new(); write!(&mut mdns_service_name, "{:016X}", num).unwrap(); - mdns.publish_service( + mdns.add( &mdns_service_name, ServiceMode::Commissionable(discriminator), )?; self.state = PaseMgrState::Enabled( Pake::new(verifier, self.epoch, self.rand), mdns_service_name, - discriminator, ); Ok(()) } - pub fn disable_pase_session(&mut self, mdns: &MdnsMgr) -> Result<(), Error> { - if let PaseMgrState::Enabled(_, mdns_service_name, discriminator) = &self.state { - mdns.unpublish_service( - mdns_service_name, - ServiceMode::Commissionable(*discriminator), - )?; + pub fn disable_pase_session(&mut self, mdns: &dyn Mdns) -> Result<(), Error> { + if let PaseMgrState::Enabled(_, mdns_service_name) = &self.state { + mdns.remove(mdns_service_name)?; } self.state = PaseMgrState::Disabled; @@ -108,7 +104,7 @@ impl PaseMgr { where F: FnOnce(&mut Pake, &mut ProtoCtx) -> Result, { - if let PaseMgrState::Enabled(pake, _, _) = &mut self.state { + if let PaseMgrState::Enabled(pake, _) = &mut self.state { let data = f(pake, ctx)?; Ok(Some(data)) @@ -134,7 +130,7 @@ impl PaseMgr { pub fn pasepake3_handler( &mut self, ctx: &mut ProtoCtx, - mdns: &MdnsMgr, + mdns: &dyn Mdns, ) -> Result<(bool, Option), Error> { let clone_data = self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?; self.disable_pase_session(mdns)?; diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index 7cf52889..5308462b 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -27,7 +27,7 @@ mod smol_udp { use log::{debug, info, warn}; use smol::net::UdpSocket; - use crate::transport::network::SocketAddr; + use crate::transport::network::{IpAddr, Ipv4Addr, SocketAddr}; pub struct UdpListener { socket: UdpSocket, @@ -44,9 +44,18 @@ mod smol_udp { Ok(listener) } - pub async fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, SocketAddr), Error> { - info!("Waiting for incoming packets"); + pub async fn join_multicast(&mut self, ip_addr: IpAddr) -> Result<(), Error> { + match ip_addr { + IpAddr::V4(ip_addr) => self + .socket + .join_multicast_v4(ip_addr, Ipv4Addr::UNSPECIFIED)?, + IpAddr::V6(ip_addr) => self.socket.join_multicast_v6(&ip_addr, 0)?, + } + + Ok(()) + } + pub async fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, SocketAddr), Error> { let (size, addr) = self.socket.recv_from(in_buf).await.map_err(|e| { warn!("Error on the network: {:?}", e); ErrorCode::Network @@ -96,6 +105,10 @@ mod dummy_udp { Ok(listener) } + pub async fn join_multicast(&mut self, ip_addr: IpAddr) -> Result<(), Error> { + Ok(()) + } + pub async fn recv(&self, _in_buf: &mut [u8]) -> Result<(usize, SocketAddr), Error> { info!("Pretending to wait for incoming packets (looping forever)"); From e8babedd87932f43e3fa9c75c12dd74c4d3a67c6 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 12 Jun 2023 11:41:33 +0000 Subject: [PATCH 56/72] Support for ESP-IDF build --- Cargo.toml | 2 +- README.md | 24 +++++++++++++++++++++--- examples/onoff_light/src/main.rs | 31 ++++++++++++++++++------------- matter/Cargo.toml | 19 +++++++++++++------ matter/build.rs | 11 +++++++++++ matter/src/transport/udp.rs | 2 ++ 6 files changed, 66 insertions(+), 23 deletions(-) create mode 100644 matter/build.rs diff --git a/Cargo.toml b/Cargo.toml index 7561a523..11f05af5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["matter", "matter_macro_derive", "tools/tlv_tool"] +members = ["matter", "matter_macro_derive"] exclude = ["examples/*"] diff --git a/README.md b/README.md index 80cb2b24..86de8807 100644 --- a/README.md +++ b/README.md @@ -47,13 +47,31 @@ Building the library: $ cargo build ``` -Building the example: +Building and running the example (Linux, MacOS X): ``` -$ RUST_LOG="matter" cargo run --example onoff_light +$ cargo run --example onoff_light ``` -With the chip-tool (the current tool for testing Matter) use the Ethernet commissioning mechanism: +Building the example (Espressif's ESP-IDF): +* Install all build prerequisites described [here](https://github.com/esp-rs/esp-idf-template#prerequisites) +* Build with the following command line: +``` +export MCU=esp32; export CARGO_TARGET_XTENSA_ESP32_ESPIDF_LINKER=ldproxy; export RUSTFLAGS="-C default-linker-libraries"; export WIFI_SSID=ssid;export WIFI_PASS=pass; cargo build --example onoff_light --no-default-features --features std,crypto_rustcrypto --target xtensa-esp32-espidf -Zbuild-std=std,panic_abort +``` +* If you are building for a different Espressif MCU, change the `MCU` variable, the `xtensa-esp32-espidf` target and the name of the `CARGO_TARGET__LINKER` variable to match your MCU and its Rust target. Available Espressif MCUs and targets are: + * esp32 / xtensa-esp32-espidf + * esp32s2 / xtensa-esp32s2-espidf + * esp32s3 / xtensa-esp32s3-espidf + * esp32c3 / riscv32imc-esp-espidf + * esp32c5 / riscv32imc-esp-espidf + * esp32c6 / risxcv32imac-esp-espidf +* Put in `WIFI_SSID` / `WIFI_PASS` the SSID & password for your wireless router +* Flash using the `espflash` utility described in the build prerequsites' link above + +## Test + +With the `chip-tool` (the current tool for testing Matter) use the Ethernet commissioning mechanism: ``` $ chip-tool pairing code 12344321 diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index e26f20f1..069f1dd6 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -31,7 +31,6 @@ use matter::data_model::system_model::descriptor; use matter::error::Error; use matter::interaction_model::core::InteractionModel; use matter::mdns::{DefaultMdns, DefaultMdnsRunner}; -use matter::persist; use matter::secure_channel::spake2p::VerifierData; use matter::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use matter::transport::{ @@ -50,7 +49,6 @@ fn main() -> Result<(), Error> { .unwrap(); thread.join().unwrap() - // run() } #[cfg(not(feature = "std"))] @@ -105,17 +103,21 @@ fn run() -> Result<(), Error> { let psm_path = std::env::temp_dir().join("matter-iot"); info!("Persisting from/to {}", psm_path.display()); - let psm = persist::FilePsm::new(psm_path)?; + #[cfg(all(feature = "std", not(target_os = "espidf")))] + let psm = matter::persist::FilePsm::new(psm_path)?; let mut buf = [0; 4096]; let buf = &mut buf; - if let Some(data) = psm.load("acls", buf)? { - matter.load_acls(data)?; - } + #[cfg(all(feature = "std", not(target_os = "espidf")))] + { + if let Some(data) = psm.load("acls", buf)? { + matter.load_acls(data)?; + } - if let Some(data) = psm.load("fabrics", buf)? { - matter.load_fabrics(data)?; + if let Some(data) = psm.load("fabrics", buf)? { + matter.load_fabrics(data)?; + } } let mut transport = Transport::new(&matter); @@ -180,12 +182,15 @@ fn run() -> Result<(), Error> { } } - if let Some(data) = transport.matter().store_fabrics(buf)? { - psm.store("fabrics", data)?; - } + #[cfg(all(feature = "std", not(target_os = "espidf")))] + { + if let Some(data) = transport.matter().store_fabrics(buf)? { + psm.store("fabrics", data)?; + } - if let Some(data) = transport.matter().store_acls(buf)? { - psm.store("acls", data)?; + if let Some(data) = transport.matter().store_acls(buf)? { + psm.store("acls", data)?; + } } } diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 2f859b7c..5b8816d8 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -16,7 +16,6 @@ path = "src/lib.rs" [features] default = ["os", "crypto_rustcrypto"] -#default = ["crypto_rustcrypto"] os = ["std", "backtrace", "env_logger", "nix", "critical-section/std", "embassy-sync/std", "embassy-time/std"] std = ["alloc", "rand", "qrcode", "async-io", "smol", "esp-idf-sys/std"] backtrace = [] @@ -51,7 +50,6 @@ domain = { version = "0.7.2", default_features = false, features = ["heapless"] # STD-only dependencies rand = { version = "0.8.5", optional = true } qrcode = { version = "0.12", default-features = false, optional = true } # Print QR code -astro-dnssd = { version = "0.3", optional = true } # On Linux needs avahi-compat-libdns_sd, i.e. on Ubuntu/Debian do `sudo apt-get install libavahi-compat-libdnssd-dev` smol = { version = "1.2", optional = true } # =1.2 for compatibility with ESP IDF async-io = { version = "=1.12", optional = true } # =1.2 for compatibility with ESP IDF @@ -72,18 +70,27 @@ crypto-bigint = { version = "0.4", default-features = false, optional = true } rand_core = { version = "0.6", default-features = false, optional = true } x509-cert = { version = "0.2.0", default-features = false, features = ["pem"], optional = true } # TODO: requires `alloc` +[target.'cfg(target_os = "macos")'.dependencies] +astro-dnssd = { version = "0.3" } + [target.'cfg(not(target_os = "espidf"))'.dependencies] mbedtls = { git = "https://github.com/fortanix/rust-mbedtls", optional = true } env_logger = { version = "0.10.0", optional = true } nix = { version = "0.26", features = ["net"], optional = true } [target.'cfg(target_os = "espidf")'.dependencies] -esp-idf-sys = { version = "0.33", default-features = false, features = ["native"] } +esp-idf-sys = { version = "0.33", default-features = false, features = ["native", "binstart"] } +esp-idf-hal = { version = "0.41", features = ["embassy-sync", "critical-section"] } +esp-idf-svc = { version = "0.46", features = ["embassy-time-driver"] } +embedded-svc = "0.25" + +[build-dependencies] +embuild = "0.31.2" [[example]] name = "onoff_light" path = "../examples/onoff_light/src/main.rs" -[[example]] -name = "speaker" -path = "../examples/speaker/src/main.rs" +# [[example]] +# name = "speaker" +# path = "../examples/speaker/src/main.rs" diff --git a/matter/build.rs b/matter/build.rs new file mode 100644 index 00000000..9e17e102 --- /dev/null +++ b/matter/build.rs @@ -0,0 +1,11 @@ +use std::env::var; + +// Necessary because of this issue: https://github.com/rust-lang/cargo/issues/9641 +fn main() -> Result<(), Box> { + if var("TARGET").unwrap().ends_with("-espidf") { + embuild::build::CfgArgs::output_propagated("ESP_IDF")?; + embuild::build::LinkArgs::output_propagated("ESP_IDF")?; + } + + Ok(()) +} diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index 5308462b..0aa65731 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -52,6 +52,8 @@ mod smol_udp { IpAddr::V6(ip_addr) => self.socket.join_multicast_v6(&ip_addr, 0)?, } + info!("Joining multicast on {:?}", ip_addr); + Ok(()) } From 62aa69202fbcf28a1b9284f73ea4944c690dcca2 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Tue, 13 Jun 2023 07:02:37 +0000 Subject: [PATCH 57/72] Workaround broken join_multicast_v4 on ESP-IDF --- examples/onoff_light/src/main.rs | 13 +++--- matter/src/mdns/astro.rs | 1 + matter/src/mdns/builtin.rs | 32 +++++++------- matter/src/transport/udp.rs | 72 ++++++++++++++++++++++++++++---- 4 files changed, 90 insertions(+), 28 deletions(-) diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 069f1dd6..7115a287 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -77,13 +77,14 @@ fn run() -> Result<(), Error> { device_name: "OnOff Light", }; - let (ipv4_addr, ipv6_addr) = initialize_network()?; + let (ipv4_addr, ipv6_addr, interface) = initialize_network()?; let mdns = DefaultMdns::new( 0, "matter-demo", ipv4_addr.octets(), Some(ipv6_addr.octets()), + interface, &dev_det, matter::MATTER_PORT, ); @@ -231,7 +232,7 @@ fn initialize_logger() { #[cfg(not(target_os = "espidf"))] #[inline(never)] -fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr), Error> { +fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr, u32), Error> { use log::error; use matter::error::ErrorCode; use nix::{net::if_::InterfaceFlags, sys::socket::SockaddrIn6}; @@ -276,7 +277,7 @@ fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr), Error> { iname, ip, ipv6 ); - Ok((ip, ipv6)) + Ok((ip, ipv6, 0 as _)) } #[cfg(target_os = "espidf")] @@ -287,7 +288,7 @@ fn initialize_logger() { #[cfg(target_os = "espidf")] #[inline(never)] -fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr), Error> { +fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr, u32), Error> { use core::time::Duration; use embedded_svc::wifi::{AuthMethod, ClientConfiguration, Configuration}; @@ -379,9 +380,11 @@ fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr), Error> { ipv6.addr[3].to_le_bytes()[3], ]; + let interface = wifi.sta_netif().get_index(); + // Not OK of course, but for a demo this is good enough // Wifi will continue to be available and working in the background core::mem::forget(wifi); - Ok((ipv4_octets.into(), ipv6_octets.into())) + Ok((ipv4_octets.into(), ipv6_octets.into(), interface)) } diff --git a/matter/src/mdns/astro.rs b/matter/src/mdns/astro.rs index 12426cb6..e7ae4c20 100644 --- a/matter/src/mdns/astro.rs +++ b/matter/src/mdns/astro.rs @@ -23,6 +23,7 @@ impl<'a> Mdns<'a> { _hostname: &str, _ip: [u8; 4], _ipv6: Option<[u8; 16]>, + _interface: u32, dev_det: &'a BasicInfoConfig<'a>, matter_port: u16, ) -> Self { diff --git a/matter/src/mdns/builtin.rs b/matter/src/mdns/builtin.rs index 95c6ad69..7b6f8912 100644 --- a/matter/src/mdns/builtin.rs +++ b/matter/src/mdns/builtin.rs @@ -19,21 +19,19 @@ use super::{ Service, ServiceMode, }; -const IP_BROADCAST_ADDRS: [(IpAddr, u16); 2] = [ - (IpAddr::V4(Ipv4Addr::new(224, 0, 0, 251)), 5353), - ( - IpAddr::V6(Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x00fb)), - 5353, - ), -]; +const IP_BIND_ADDR: IpAddr = IpAddr::V6(Ipv6Addr::UNSPECIFIED); -const IP_BIND_ADDR: (IpAddr, u16) = (IpAddr::V6(Ipv6Addr::UNSPECIFIED), 5353); +const IP_BROADCAST_ADDR: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251); +const IPV6_BROADCAST_ADDR: Ipv6Addr = Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x00fb); + +const PORT: u16 = 5353; type MdnsTxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; type MdnsRxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>; pub struct Mdns<'a> { host: Host<'a>, + interface: u32, dev_det: &'a BasicInfoConfig<'a>, matter_port: u16, services: RefCell, ServiceMode), 4>>, @@ -47,6 +45,7 @@ impl<'a> Mdns<'a> { hostname: &'a str, ip: [u8; 4], ipv6: Option<[u8; 16]>, + interface: u32, dev_det: &'a BasicInfoConfig<'a>, matter_port: u16, ) -> Self { @@ -57,6 +56,7 @@ impl<'a> Mdns<'a> { ip, ipv6, }, + interface, dev_det, matter_port, services: RefCell::new(heapless::Vec::new()), @@ -121,11 +121,10 @@ impl<'a> MdnsRunner<'a> { let tx_pipe = &tx_pipe; let rx_pipe = &rx_pipe; - let mut udp = UdpListener::new(SocketAddr::new(IP_BIND_ADDR.0, IP_BIND_ADDR.1)).await?; + let mut udp = UdpListener::new(SocketAddr::new(IP_BIND_ADDR, PORT)).await?; - for (ip, _) in IP_BROADCAST_ADDRS { - udp.join_multicast(ip).await?; - } + udp.join_multicast_v6(IPV6_BROADCAST_ADDR, self.0.interface)?; + udp.join_multicast_v4(IP_BROADCAST_ADDR, Ipv4Addr::from(self.0.host.ip))?; let udp = &udp; @@ -188,7 +187,10 @@ impl<'a> MdnsRunner<'a> { ) .await; - for (addr, port) in IP_BROADCAST_ADDRS { + for addr in [ + IpAddr::V4(IP_BROADCAST_ADDR), + IpAddr::V6(IPV6_BROADCAST_ADDR), + ] { loop { let sent = { let mut data = tx_pipe.data.lock().await; @@ -197,12 +199,12 @@ impl<'a> MdnsRunner<'a> { let len = self.0.host.broadcast(&self.0, data.buf, 60)?; if len > 0 { - info!("Broadasting mDNS entry to {}:{}", addr, port); + info!("Broadasting mDNS entry to {}:{}", addr, PORT); data.chunk = Some(Chunk { start: 0, end: len, - addr: Address::Udp(SocketAddr::new(addr, port)), + addr: Address::Udp(SocketAddr::new(addr, PORT)), }); tx_pipe.data_supplied_notification.signal(()); diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index 0aa65731..bf3f36f5 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -27,7 +27,7 @@ mod smol_udp { use log::{debug, info, warn}; use smol::net::UdpSocket; - use crate::transport::network::{IpAddr, Ipv4Addr, SocketAddr}; + use crate::transport::network::{Ipv4Addr, Ipv6Addr, SocketAddr}; pub struct UdpListener { socket: UdpSocket, @@ -44,15 +44,71 @@ mod smol_udp { Ok(listener) } - pub async fn join_multicast(&mut self, ip_addr: IpAddr) -> Result<(), Error> { - match ip_addr { - IpAddr::V4(ip_addr) => self - .socket - .join_multicast_v4(ip_addr, Ipv4Addr::UNSPECIFIED)?, - IpAddr::V6(ip_addr) => self.socket.join_multicast_v6(&ip_addr, 0)?, + pub fn join_multicast_v6( + &mut self, + multiaddr: Ipv6Addr, + interface: u32, + ) -> Result<(), Error> { + self.socket.join_multicast_v6(&multiaddr, interface)?; + + info!("Joined IPV6 multicast {}/{}", multiaddr, interface); + + Ok(()) + } + + pub fn join_multicast_v4( + &mut self, + multiaddr: Ipv4Addr, + interface: Ipv4Addr, + ) -> Result<(), Error> { + #[cfg(not(target_os = "espidf"))] + self.socket.join_multicast_v4(multiaddr, interface)?; + + // join_multicast_v4() is broken for ESP-IDF, most likely due to wrong `ip_mreq` signature in the `libc` crate + // Note that also most *_multicast_v4 and *_multicast_v6 methods are broken as well in Rust STD for the ESP-IDF + // due to mismatch w.r.t. sizes (u8 expected but u32 passed to setsockopt() and sometimes the other way around) + #[cfg(target_os = "espidf")] + { + fn esp_setsockopt( + socket: &mut UdpSocket, + proto: u32, + option: u32, + value: T, + ) -> Result<(), Error> { + use std::os::fd::AsRawFd; + + esp_idf_sys::esp!(unsafe { + esp_idf_sys::lwip_setsockopt( + socket.as_raw_fd(), + proto as _, + option as _, + &value as *const _ as *const _, + core::mem::size_of::() as _, + ) + }) + .map_err(|_| ErrorCode::StdIoError)?; + + Ok(()) + } + + let mreq = esp_idf_sys::ip_mreq { + imr_multiaddr: esp_idf_sys::in_addr { + s_addr: u32::from_ne_bytes(multiaddr.octets()), + }, + imr_interface: esp_idf_sys::in_addr { + s_addr: u32::from_ne_bytes(interface.octets()), + }, + }; + + esp_setsockopt( + &mut self.socket, + esp_idf_sys::IPPROTO_IP, + esp_idf_sys::IP_ADD_MEMBERSHIP, + mreq, + )?; } - info!("Joining multicast on {:?}", ip_addr); + info!("Joined IP multicast {}/{}", multiaddr, interface); Ok(()) } From 5b9fd502c7d70a8377ae856c4c25335c1ad4efbb Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Tue, 13 Jun 2023 10:12:42 +0000 Subject: [PATCH 58/72] Fix the no_std build --- examples/onoff_light/src/main.rs | 41 ++++++++++++++++++++++++++++--- matter/src/crypto/crypto_dummy.rs | 2 +- matter/src/transport/udp.rs | 26 ++++++++++++++++++-- 3 files changed, 62 insertions(+), 7 deletions(-) diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 7115a287..31c66b71 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -51,9 +51,10 @@ fn main() -> Result<(), Error> { thread.join().unwrap() } +// NOTE: For no_std, name this entry point according to your MCU platform #[cfg(not(feature = "std"))] #[no_mangle] -fn main() { +fn app_main() { run().unwrap(); } @@ -93,11 +94,27 @@ fn run() -> Result<(), Error> { let dev_att = dev_att::HardCodedDevAtt::new(); - let matter = Matter::new_default( + #[cfg(feature = "std")] + let epoch = matter::utils::epoch::sys_epoch; + + #[cfg(feature = "std")] + let rand = matter::utils::rand::sys_rand; + + // NOTE: For no_std, provide your own function here + #[cfg(not(feature = "std"))] + let epoch = matter::utils::epoch::dummy_epoch; + + // NOTE: For no_std, provide your own function here + #[cfg(not(feature = "std"))] + let rand = matter::utils::rand::dummy_rand; + + let matter = Matter::new( // vid/pid should match those in the DAC &dev_det, &dev_att, &mdns, + epoch, + rand, matter::MATTER_PORT, ); @@ -203,6 +220,10 @@ fn run() -> Result<(), Error> { let mut fut = pin!(async move { select(&mut io_fut, &mut mdns_fut).await.unwrap() }); + info!("Final future: {:p}", &mut fut); + + // NOTE: For no_std, replace with your own no_std way of polling the future + #[cfg(feature = "std")] smol::block_on(&mut fut)?; Ok::<_, matter::error::Error>(()) @@ -222,7 +243,19 @@ fn handler<'a>(matter: &'a Matter<'a>) -> impl Handler + 'a { ) } -#[cfg(not(target_os = "espidf"))] +// NOTE: For no_std, implement here your own way of initializing the logger +#[cfg(all(not(feature = "std"), not(target_os = "espidf")))] +#[inline(never)] +fn initialize_logger() {} + +// NOTE: For no_std, implement here your own way of initializing the network +#[cfg(all(not(feature = "std"), not(target_os = "espidf")))] +#[inline(never)] +fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr, u32), Error> { + Ok((Ipv4Addr::UNSPECIFIED, Ipv6Addr::UNSPECIFIED, 0)) +} + +#[cfg(all(feature = "std", not(target_os = "espidf")))] #[inline(never)] fn initialize_logger() { env_logger::init_from_env( @@ -230,7 +263,7 @@ fn initialize_logger() { ); } -#[cfg(not(target_os = "espidf"))] +#[cfg(all(feature = "std", not(target_os = "espidf")))] #[inline(never)] fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr, u32), Error> { use log::error; diff --git a/matter/src/crypto/crypto_dummy.rs b/matter/src/crypto/crypto_dummy.rs index 827b7f3e..7a84e645 100644 --- a/matter/src/crypto/crypto_dummy.rs +++ b/matter/src/crypto/crypto_dummy.rs @@ -27,7 +27,7 @@ pub fn hkdf_sha256(_salt: &[u8], _ikm: &[u8], _info: &[u8], _key: &mut [u8]) -> Ok(()) } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Sha256 {} impl Sha256 { diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index bf3f36f5..9b234894 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -150,7 +150,7 @@ mod dummy_udp { use crate::error::*; use log::{debug, info}; - use crate::transport::network::SocketAddr; + use crate::transport::network::{Ipv4Addr, Ipv6Addr, SocketAddr}; pub struct UdpListener {} @@ -163,7 +163,29 @@ mod dummy_udp { Ok(listener) } - pub async fn join_multicast(&mut self, ip_addr: IpAddr) -> Result<(), Error> { + pub fn join_multicast_v6( + &mut self, + multiaddr: Ipv6Addr, + interface: u32, + ) -> Result<(), Error> { + info!( + "Pretending to join IPV6 multicast {}/{}", + multiaddr, interface + ); + + Ok(()) + } + + pub fn join_multicast_v4( + &mut self, + multiaddr: Ipv4Addr, + interface: Ipv4Addr, + ) -> Result<(), Error> { + info!( + "Pretending to join IP multicast {}/{}", + multiaddr, interface + ); + Ok(()) } From 879f816438bbb962dbe4ea60d870724c08e2c917 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Tue, 13 Jun 2023 10:38:02 +0000 Subject: [PATCH 59/72] More comments for tailoring the example for no_std --- examples/onoff_light/src/main.rs | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 31c66b71..a5340f65 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -51,7 +51,7 @@ fn main() -> Result<(), Error> { thread.join().unwrap() } -// NOTE: For no_std, name this entry point according to your MCU platform +// NOTE (no_std): For no_std, name this entry point according to your MCU platform #[cfg(not(feature = "std"))] #[no_mangle] fn app_main() { @@ -100,11 +100,11 @@ fn run() -> Result<(), Error> { #[cfg(feature = "std")] let rand = matter::utils::rand::sys_rand; - // NOTE: For no_std, provide your own function here + // NOTE (no_std): For no_std, provide your own function here #[cfg(not(feature = "std"))] let epoch = matter::utils::epoch::dummy_epoch; - // NOTE: For no_std, provide your own function here + // NOTE (no_std): For no_std, provide your own function here #[cfg(not(feature = "std"))] let rand = matter::utils::rand::dummy_rand; @@ -175,6 +175,8 @@ fn run() -> Result<(), Error> { let tx_buf = &mut tx_buf; let mut io_fut = pin!(async move { + // NOTE (no_std): On no_std, the `UdpListener` implementation is a no-op so you might want to + // replace it with your own UDP stack let udp = UdpListener::new(SocketAddr::new( IpAddr::V6(Ipv6Addr::UNSPECIFIED), matter::MATTER_PORT, @@ -216,17 +218,21 @@ fn run() -> Result<(), Error> { Ok::<_, matter::error::Error>(()) }); + // NOTE (no_std): On no_std, the `run_udp` is a no-op so you might want to replace it with `run` and + // connect the pipes of the `run` method with your own UDP stack let mut mdns_fut = pin!(async move { mdns_runner.run_udp().await }); let mut fut = pin!(async move { select(&mut io_fut, &mut mdns_fut).await.unwrap() }); - info!("Final future: {:p}", &mut fut); - - // NOTE: For no_std, replace with your own no_std way of polling the future #[cfg(feature = "std")] smol::block_on(&mut fut)?; - Ok::<_, matter::error::Error>(()) + // NOTE (no_std): For no_std, replace with your own more efficient no_std executor, + // because the executor used below is a simple busy-loop poller + #[cfg(not(feature = "std"))] + embassy_futures::block_on(&mut fut)?; + + Ok(()) } fn handler<'a>(matter: &'a Matter<'a>) -> impl Handler + 'a { @@ -243,12 +249,12 @@ fn handler<'a>(matter: &'a Matter<'a>) -> impl Handler + 'a { ) } -// NOTE: For no_std, implement here your own way of initializing the logger +// NOTE (no_std): For no_std, implement here your own way of initializing the logger #[cfg(all(not(feature = "std"), not(target_os = "espidf")))] #[inline(never)] fn initialize_logger() {} -// NOTE: For no_std, implement here your own way of initializing the network +// NOTE (no_std): For no_std, implement here your own way of initializing the network #[cfg(all(not(feature = "std"), not(target_os = "espidf")))] #[inline(never)] fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr, u32), Error> { From 831853630bd2d4d911f963822283958066fba15c Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Fri, 16 Jun 2023 18:42:11 +0000 Subject: [PATCH 60/72] Add from/to TLV for i16, i32 and i64 --- matter/src/tlv/parser.rs | 27 +++++++++++++++++++++++++++ matter/src/tlv/traits.rs | 4 ++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/matter/src/tlv/parser.rs b/matter/src/tlv/parser.rs index 8bfdd28a..5e6964c3 100644 --- a/matter/src/tlv/parser.rs +++ b/matter/src/tlv/parser.rs @@ -357,6 +357,14 @@ impl<'a> TLVElement<'a> { } } + pub fn i16(&self) -> Result { + match self.element_type { + ElementType::S8(a) => Ok(a.into()), + ElementType::S16(a) => Ok(a), + _ => Err(ErrorCode::TLVTypeMismatch.into()), + } + } + pub fn u16(&self) -> Result { match self.element_type { ElementType::U8(a) => Ok(a.into()), @@ -365,6 +373,15 @@ impl<'a> TLVElement<'a> { } } + pub fn i32(&self) -> Result { + match self.element_type { + ElementType::S8(a) => Ok(a.into()), + ElementType::S16(a) => Ok(a.into()), + ElementType::S32(a) => Ok(a), + _ => Err(ErrorCode::TLVTypeMismatch.into()), + } + } + pub fn u32(&self) -> Result { match self.element_type { ElementType::U8(a) => Ok(a.into()), @@ -374,6 +391,16 @@ impl<'a> TLVElement<'a> { } } + pub fn i64(&self) -> Result { + match self.element_type { + ElementType::S8(a) => Ok(a.into()), + ElementType::S16(a) => Ok(a.into()), + ElementType::S32(a) => Ok(a.into()), + ElementType::S64(a) => Ok(a), + _ => Err(ErrorCode::TLVTypeMismatch.into()), + } + } + pub fn u64(&self) -> Result { match self.element_type { ElementType::U8(a) => Ok(a.into()), diff --git a/matter/src/tlv/traits.rs b/matter/src/tlv/traits.rs index 9a8edcdd..2156f585 100644 --- a/matter/src/tlv/traits.rs +++ b/matter/src/tlv/traits.rs @@ -91,7 +91,7 @@ macro_rules! fromtlv_for { }; } -fromtlv_for!(u8 u16 u32 u64 bool); +fromtlv_for!(i8 u8 i16 u16 i32 u32 i64 u64 bool); pub trait ToTLV { fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error>; @@ -139,7 +139,7 @@ impl<'a, T: ToTLV> ToTLV for &'a [T] { } // Generate ToTLV for standard data types -totlv_for!(i8 u8 u16 u32 u64 bool); +totlv_for!(i8 u8 i16 u16 i32 u32 i64 u64 bool); // We define a few common data types that will be required here // From 44e01a5881b1562a503f4f87c8f2cb2b2d58439d Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sat, 17 Jun 2023 14:00:44 +0000 Subject: [PATCH 61/72] Configurable parts_list in descriptor --- matter/src/data_model/root_endpoint.rs | 2 +- .../src/data_model/system_model/descriptor.rs | 69 ++++++++++++++++--- matter/tests/common/im_engine.rs | 2 +- 3 files changed, 60 insertions(+), 13 deletions(-) diff --git a/matter/src/data_model/root_endpoint.rs b/matter/src/data_model/root_endpoint.rs index 78b8cfb9..21691cdf 100644 --- a/matter/src/data_model/root_endpoint.rs +++ b/matter/src/data_model/root_endpoint.rs @@ -27,7 +27,7 @@ use super::{ }; pub type RootEndpointHandler<'a> = handler_chain_type!( - DescriptorCluster, + DescriptorCluster<'static>, BasicInfoCluster<'a>, GenCommCluster<'a>, NwCommCluster, diff --git a/matter/src/data_model/system_model/descriptor.rs b/matter/src/data_model/system_model/descriptor.rs index b434586d..00e1191a 100644 --- a/matter/src/data_model/system_model/descriptor.rs +++ b/matter/src/data_model/system_model/descriptor.rs @@ -53,13 +53,63 @@ pub const CLUSTER: Cluster<'static> = Cluster { commands: &[], }; -pub struct DescriptorCluster { +struct StandardPartsMatcher; + +impl PartsMatcher for StandardPartsMatcher { + fn describe(&self, our_endpoint: EndptId, endpoint: EndptId) -> bool { + our_endpoint == 0 && endpoint != our_endpoint + } +} + +struct AggregatorPartsMatcher; + +impl PartsMatcher for AggregatorPartsMatcher { + fn describe(&self, our_endpoint: EndptId, endpoint: EndptId) -> bool { + endpoint != our_endpoint && endpoint != 0 + } +} + +pub trait PartsMatcher { + fn describe(&self, our_endpoint: EndptId, endpoint: EndptId) -> bool; +} + +impl PartsMatcher for &T +where + T: PartsMatcher, +{ + fn describe(&self, our_endpoint: EndptId, endpoint: EndptId) -> bool { + (**self).describe(our_endpoint, endpoint) + } +} + +impl PartsMatcher for &mut T +where + T: PartsMatcher, +{ + fn describe(&self, our_endpoint: EndptId, endpoint: EndptId) -> bool { + (**self).describe(our_endpoint, endpoint) + } +} + +pub struct DescriptorCluster<'a> { + matcher: &'a dyn PartsMatcher, data_ver: Dataver, } -impl DescriptorCluster { +impl DescriptorCluster<'static> { pub fn new(rand: Rand) -> Self { + Self::new_matching(&StandardPartsMatcher, rand) + } + + pub fn new_aggregator(rand: Rand) -> Self { + Self::new_matching(&AggregatorPartsMatcher, rand) + } +} + +impl<'a> DescriptorCluster<'a> { + pub fn new_matching(matcher: &'a dyn PartsMatcher, rand: Rand) -> DescriptorCluster<'a> { Self { + matcher, data_ver: Dataver::new(rand), } } @@ -159,12 +209,9 @@ impl DescriptorCluster { ) -> Result<(), Error> { tw.start_array(tag)?; - if endpoint_id == 0 { - // TODO: If endpoint is another than 0, need to figure out what to do - for endpoint in node.endpoints { - if endpoint.id != 0 { - tw.u16(TagType::Anonymous, endpoint.id)?; - } + for endpoint in node.endpoints { + if self.matcher.describe(endpoint_id, endpoint.id) { + tw.u16(TagType::Anonymous, endpoint.id)?; } } @@ -184,15 +231,15 @@ impl DescriptorCluster { } } -impl Handler for DescriptorCluster { +impl<'a> Handler for DescriptorCluster<'a> { fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { DescriptorCluster::read(self, attr, encoder) } } -impl NonBlockingHandler for DescriptorCluster {} +impl<'a> NonBlockingHandler for DescriptorCluster<'a> {} -impl ChangeNotifier<()> for DescriptorCluster { +impl<'a> ChangeNotifier<()> for DescriptorCluster<'a> { fn consume_change(&mut self) -> Option<()> { self.data_ver.consume_change(()) } diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index 13da8cd5..166b7fc8 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -101,7 +101,7 @@ impl<'a> ImInput<'a> { } } -pub type DmHandler<'a> = handler_chain_type!(OnOffCluster, EchoCluster, DescriptorCluster, EchoCluster | RootEndpointHandler<'a>); +pub type DmHandler<'a> = handler_chain_type!(OnOffCluster, EchoCluster, DescriptorCluster<'a>, EchoCluster | RootEndpointHandler<'a>); pub fn matter(mdns: &mut dyn Mdns) -> Matter<'_> { #[cfg(feature = "std")] From 7f9ccbc38d3a64da787e18b6c4c6d32641bfaecd Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sat, 10 Jun 2023 14:01:35 +0000 Subject: [PATCH 62/72] Sequential Exchange API --- examples/onoff_light/src/main.rs | 188 ++-- matter/src/data_model/cluster_on_off.rs | 36 +- matter/src/data_model/core.rs | 348 ++----- matter/src/data_model/objects/dataver.rs | 24 +- matter/src/data_model/objects/encoder.rs | 172 ++-- matter/src/data_model/objects/handler.rs | 222 ++++- matter/src/data_model/objects/metadata.rs | 178 ++++ matter/src/data_model/objects/mod.rs | 3 + matter/src/data_model/objects/node.rs | 123 +-- matter/src/data_model/root_endpoint.rs | 2 +- .../src/data_model/sdm/admin_commissioning.rs | 14 +- .../data_model/sdm/general_commissioning.rs | 42 +- matter/src/data_model/sdm/noc.rs | 145 +-- .../data_model/system_model/access_control.rs | 12 +- matter/src/interaction_model/core.rs | 893 ++++++++---------- matter/src/lib.rs | 20 + matter/src/secure_channel/case.rs | 202 ++-- matter/src/secure_channel/common.rs | 16 +- matter/src/secure_channel/core.rs | 67 +- matter/src/secure_channel/pake.rs | 356 +++---- matter/src/transport/core.rs | 544 +++++++---- matter/src/transport/exchange.rs | 777 +++++---------- matter/src/transport/mod.rs | 2 +- matter/src/transport/proto_ctx.rs | 41 - matter/src/transport/runner.rs | 392 ++++++++ matter/tests/common/echo_cluster.rs | 53 +- matter/tests/common/handlers.rs | 278 +++--- matter/tests/common/im_engine.rs | 444 ++++++--- matter/tests/data_model/acl_and_dataver.rs | 182 ++-- matter/tests/data_model/attribute_lists.rs | 15 +- matter/tests/data_model/attributes.rs | 53 +- matter/tests/data_model/commands.rs | 16 +- matter/tests/data_model/long_reads.rs | 135 ++- matter/tests/data_model/timed_requests.rs | 36 +- matter/tests/interaction_model.rs | 152 --- sdkconfig.defaults | 6 + 36 files changed, 3200 insertions(+), 2989 deletions(-) create mode 100644 matter/src/data_model/objects/metadata.rs delete mode 100644 matter/src/transport/proto_ctx.rs create mode 100644 matter/src/transport/runner.rs delete mode 100644 matter/tests/interaction_model.rs create mode 100644 sdkconfig.defaults diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index a5340f65..ecfc71eb 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -23,20 +23,15 @@ use log::info; use matter::core::{CommissioningData, Matter}; use matter::data_model::cluster_basic_information::BasicInfoConfig; use matter::data_model::cluster_on_off; -use matter::data_model::core::DataModel; use matter::data_model::device_types::DEV_TYPE_ON_OFF_LIGHT; use matter::data_model::objects::*; use matter::data_model::root_endpoint; use matter::data_model::system_model::descriptor; use matter::error::Error; -use matter::interaction_model::core::InteractionModel; use matter::mdns::{DefaultMdns, DefaultMdnsRunner}; use matter::secure_channel::spake2p::VerifierData; -use matter::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; -use matter::transport::{ - core::RecvAction, core::Transport, packet::MAX_RX_BUF_SIZE, packet::MAX_TX_BUF_SIZE, - udp::UdpListener, -}; +use matter::transport::network::{Ipv4Addr, Ipv6Addr}; +use matter::transport::runner::{RxBuf, TransportRunner, TxBuf}; use matter::utils::select::EitherUnwrap; mod dev_att; @@ -44,7 +39,7 @@ mod dev_att; #[cfg(feature = "std")] fn main() -> Result<(), Error> { let thread = std::thread::Builder::new() - .stack_size(120 * 1024) + .stack_size(140 * 1024) .spawn(run) .unwrap(); @@ -62,10 +57,10 @@ fn run() -> Result<(), Error> { initialize_logger(); info!( - "Matter memory: mDNS={}, Matter={}, Transport={}", + "Matter memory: mDNS={}, Matter={}, TransportRunner={}", core::mem::size_of::(), core::mem::size_of::(), - core::mem::size_of::(), + core::mem::size_of::(), ); let dev_det = BasicInfoConfig { @@ -92,6 +87,8 @@ fn run() -> Result<(), Error> { let mut mdns_runner = DefaultMdnsRunner::new(&mdns); + info!("mDNS initialized: {:p}, {:p}", &mdns, &mdns_runner); + let dev_att = dev_att::HardCodedDevAtt::new(); #[cfg(feature = "std")] @@ -118,36 +115,25 @@ fn run() -> Result<(), Error> { matter::MATTER_PORT, ); - let psm_path = std::env::temp_dir().join("matter-iot"); - info!("Persisting from/to {}", psm_path.display()); - - #[cfg(all(feature = "std", not(target_os = "espidf")))] - let psm = matter::persist::FilePsm::new(psm_path)?; + info!("Matter initialized: {:p}", &matter); - let mut buf = [0; 4096]; - let buf = &mut buf; + let mut runner = TransportRunner::new(&matter); - #[cfg(all(feature = "std", not(target_os = "espidf")))] - { - if let Some(data) = psm.load("acls", buf)? { - matter.load_acls(data)?; - } + info!("Transport Runner initialized: {:p}", &runner); - if let Some(data) = psm.load("fabrics", buf)? { - matter.load_fabrics(data)?; - } - } + let mut tx_buf = TxBuf::uninit(); + let mut rx_buf = RxBuf::uninit(); - let mut transport = Transport::new(&matter); + // #[cfg(all(feature = "std", not(target_os = "espidf")))] + // { + // if let Some(data) = psm.load("acls", buf)? { + // matter.load_acls(data)?; + // } - transport.start( - CommissioningData { - // TODO: Hard-coded for now - verifier: VerifierData::new_with_pw(123456, *matter.borrow()), - discriminator: 250, - }, - buf, - )?; + // if let Some(data) = psm.load("fabrics", buf)? { + // matter.load_fabrics(data)?; + // } + // } let node = Node { id: 0, @@ -161,69 +147,48 @@ fn run() -> Result<(), Error> { ], }; - let mut handler = handler(&matter); + let handler = HandlerCompat(handler(&matter)); - let mut im = InteractionModel(DataModel::new(matter.borrow(), &node, &mut handler)); - - let mut rx_buf = [0; MAX_RX_BUF_SIZE]; - let mut tx_buf = [0; MAX_TX_BUF_SIZE]; - - let im = &mut im; - let mdns_runner = &mut mdns_runner; - let transport = &mut transport; - let rx_buf = &mut rx_buf; + let matter = &matter; + let node = &node; + let handler = &handler; + let runner = &mut runner; let tx_buf = &mut tx_buf; + let rx_buf = &mut rx_buf; - let mut io_fut = pin!(async move { - // NOTE (no_std): On no_std, the `UdpListener` implementation is a no-op so you might want to - // replace it with your own UDP stack - let udp = UdpListener::new(SocketAddr::new( - IpAddr::V6(Ipv6Addr::UNSPECIFIED), - matter::MATTER_PORT, - )) - .await?; - - loop { - let (len, addr) = udp.recv(rx_buf).await?; - - let mut completion = transport.recv(Address::Udp(addr), &mut rx_buf[..len], tx_buf); - - while let Some(action) = completion.next_action()? { - match action { - RecvAction::Send(addr, buf) => { - udp.send(addr.unwrap_udp(), buf).await?; - } - RecvAction::Interact(mut ctx) => { - if im.handle(&mut ctx)? && ctx.send()? { - udp.send(ctx.tx.peer.unwrap_udp(), ctx.tx.as_slice()) - .await?; - } - } - } - } - - #[cfg(all(feature = "std", not(target_os = "espidf")))] - { - if let Some(data) = transport.matter().store_fabrics(buf)? { - psm.store("fabrics", data)?; - } - - if let Some(data) = transport.matter().store_acls(buf)? { - psm.store("acls", data)?; - } - } - } - - #[allow(unreachable_code)] - Ok::<_, matter::error::Error>(()) - }); + info!( + "About to run wth node {:p}, handler {:p}, transport runner {:p}, mdns_runner {:p}", + node, handler, runner, &mdns_runner + ); + + let mut fut = pin!(async move { + // NOTE (no_std): On no_std, the `run_udp` is a no-op so you might want to replace it with `run` and + // connect the pipes of the `run` method with your own UDP stack + let mut transport = pin!(runner.run_udp( + tx_buf, + rx_buf, + CommissioningData { + // TODO: Hard-coded for now + verifier: VerifierData::new_with_pw(123456, *matter.borrow()), + discriminator: 250, + }, + &handler, + )); - // NOTE (no_std): On no_std, the `run_udp` is a no-op so you might want to replace it with `run` and - // connect the pipes of the `run` method with your own UDP stack - let mut mdns_fut = pin!(async move { mdns_runner.run_udp().await }); + // NOTE (no_std): On no_std, the `run_udp` is a no-op so you might want to replace it with `run` and + // connect the pipes of the `run` method with your own UDP stack + let mut mdns = pin!(mdns_runner.run_udp()); - let mut fut = pin!(async move { select(&mut io_fut, &mut mdns_fut).await.unwrap() }); + select( + &mut transport, + &mut mdns, + //save(transport, &psm), + ) + .await + .unwrap() + }); + // NOTE: For no_std, replace with your own no_std way of polling the future #[cfg(feature = "std")] smol::block_on(&mut fut)?; @@ -235,18 +200,33 @@ fn run() -> Result<(), Error> { Ok(()) } -fn handler<'a>(matter: &'a Matter<'a>) -> impl Handler + 'a { - root_endpoint::handler(0, matter) - .chain( - 1, - descriptor::ID, - descriptor::DescriptorCluster::new(*matter.borrow()), - ) - .chain( - 1, - cluster_on_off::ID, - cluster_on_off::OnOffCluster::new(*matter.borrow()), - ) +const NODE: Node<'static> = Node { + id: 0, + endpoints: &[ + root_endpoint::endpoint(0), + Endpoint { + id: 1, + device_type: DEV_TYPE_ON_OFF_LIGHT, + clusters: &[descriptor::CLUSTER, cluster_on_off::CLUSTER], + }, + ], +}; + +fn handler<'a>(matter: &'a Matter<'a>) -> impl Metadata + NonBlockingHandler + 'a { + ( + NODE, + root_endpoint::handler(0, matter) + .chain( + 1, + descriptor::ID, + descriptor::DescriptorCluster::new(*matter.borrow()), + ) + .chain( + 1, + cluster_on_off::ID, + cluster_on_off::OnOffCluster::new(*matter.borrow()), + ), + ) } // NOTE (no_std): For no_std, implement here your own way of initializing the logger diff --git a/matter/src/data_model/cluster_on_off.rs b/matter/src/data_model/cluster_on_off.rs index 1a26522a..8d03d9b9 100644 --- a/matter/src/data_model/cluster_on_off.rs +++ b/matter/src/data_model/cluster_on_off.rs @@ -15,12 +15,12 @@ * limitations under the License. */ -use core::convert::TryInto; +use core::{cell::Cell, convert::TryInto}; use super::objects::*; use crate::{ - attribute_enum, cmd_enter, command_enum, error::Error, interaction_model::core::Transaction, - tlv::TLVElement, utils::rand::Rand, + attribute_enum, cmd_enter, command_enum, error::Error, tlv::TLVElement, + transport::exchange::Exchange, utils::rand::Rand, }; use log::info; use strum::{EnumDiscriminants, FromRepr}; @@ -66,20 +66,20 @@ pub const CLUSTER: Cluster<'static> = Cluster { pub struct OnOffCluster { data_ver: Dataver, - on: bool, + on: Cell, } impl OnOffCluster { pub fn new(rand: Rand) -> Self { Self { data_ver: Dataver::new(rand), - on: false, + on: Cell::new(false), } } - pub fn set(&mut self, on: bool) { - if self.on != on { - self.on = on; + pub fn set(&self, on: bool) { + if self.on.get() != on { + self.on.set(on); self.data_ver.changed(); } } @@ -90,7 +90,7 @@ impl OnOffCluster { CLUSTER.read(attr.attr_id, writer) } else { match attr.attr_id.try_into()? { - Attributes::OnOff(codec) => codec.encode(writer, self.on), + Attributes::OnOff(codec) => codec.encode(writer, self.on.get()), } } } else { @@ -98,7 +98,7 @@ impl OnOffCluster { } } - pub fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + pub fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { let data = data.with_dataver(self.data_ver.get())?; match attr.attr_id.try_into()? { @@ -111,8 +111,8 @@ impl OnOffCluster { } pub fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + _exchange: &Exchange, cmd: &CmdDetails, _data: &TLVElement, _encoder: CmdDataEncoder, @@ -128,12 +128,10 @@ impl OnOffCluster { } Commands::Toggle => { cmd_enter!("Toggle"); - self.set(!self.on); + self.set(!self.on.get()); } } - transaction.complete(); - self.data_ver.changed(); Ok(()) @@ -145,18 +143,18 @@ impl Handler for OnOffCluster { OnOffCluster::read(self, attr, encoder) } - fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { OnOffCluster::write(self, attr, data) } fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { - OnOffCluster::invoke(self, transaction, cmd, data, encoder) + OnOffCluster::invoke(self, exchange, cmd, data, encoder) } } diff --git a/matter/src/data_model/core.rs b/matter/src/data_model/core.rs index 20efeb77..69935c5c 100644 --- a/matter/src/data_model/core.rs +++ b/matter/src/data_model/core.rs @@ -15,287 +15,127 @@ * limitations under the License. */ -use core::cell::RefCell; +use core::sync::atomic::{AtomicU32, Ordering}; use super::objects::*; use crate::{ - acl::{Accessor, AclMgr}, + alloc, error::*, - interaction_model::core::{Interaction, Transaction}, - tlv::TLVWriter, - transport::packet::Packet, + interaction_model::core::Interaction, + transport::{exchange::Exchange, packet::Packet}, }; -pub struct DataModel<'a, T> { - pub acl_mgr: &'a RefCell, - pub node: &'a Node<'a>, - pub handler: T, -} +// TODO: For now... +static SUBS_ID: AtomicU32 = AtomicU32::new(1); -impl<'a, T> DataModel<'a, T> { - pub const fn new(acl_mgr: &'a RefCell, node: &'a Node<'a>, handler: T) -> Self { - Self { - acl_mgr, - node, - handler, - } +pub struct DataModel(T); + +impl DataModel { + pub fn new(handler: T) -> Self { + Self(handler) } - pub fn handle( - &mut self, - interaction: Interaction, - tx: &mut Packet, - transaction: &mut Transaction, - ) -> Result + pub async fn handle<'r, 'p>( + &self, + exchange: &'r mut Exchange<'_>, + rx: &'r mut Packet<'p>, + tx: &'r mut Packet<'p>, + rx_status: &'r mut Packet<'p>, + ) -> Result<(), Error> where - T: Handler, + T: DataModelHandler, { - let accessor = Accessor::for_session(transaction.session(), self.acl_mgr); - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - match interaction { - Interaction::Read(req) => { - let mut resume_path = None; - - for item in self.node.read(&req, &accessor) { - if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? - { - resume_path = Some(path); - break; + let timeout = Interaction::timeout(exchange, rx, tx).await?; + + let mut interaction = alloc!(Interaction::new( + exchange, + rx, + tx, + rx_status, + || SUBS_ID.fetch_add(1, Ordering::SeqCst), + timeout, + )?); + + #[cfg(feature = "alloc")] + let interaction = &mut *interaction; + + #[cfg(not(feature = "alloc"))] + let interaction = &mut interaction; + + #[cfg(feature = "nightly")] + let metadata = self.0.lock().await; + + #[cfg(not(feature = "nightly"))] + let metadata = self.0.lock(); + + if interaction.start().await? { + match interaction { + Interaction::Read { + req, + ref mut driver, + } => { + let accessor = driver.accessor()?; + + 'outer: for item in metadata.node().read(req, None, &accessor) { + while !AttrDataEncoder::handle_read(&item, &self.0, &mut driver.writer()?) + .await? + { + if !driver.send_chunk(req).await? { + break 'outer; + } + } } - } - - req.complete(tx, transaction, resume_path) - } - Interaction::Write(req) => { - for item in self.node.write(&req, &accessor) { - AttrDataEncoder::handle_write(item, &mut self.handler, &mut tw)?; - } - req.complete(tx, transaction) - } - Interaction::Invoke(req) => { - for item in self.node.invoke(&req, &accessor) { - CmdDataEncoder::handle(item, &mut self.handler, transaction, &mut tw)?; + driver.complete(req).await?; } - - req.complete(tx, transaction) - } - Interaction::Subscribe(req) => { - let mut resume_path = None; - - for item in self.node.subscribing_read(&req, &accessor) { - if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? - { - resume_path = Some(path); - break; + Interaction::Write { + req, + ref mut driver, + } => { + let accessor = driver.accessor()?; + + for item in metadata.node().write(req, &accessor) { + AttrDataEncoder::handle_write(&item, &self.0, &mut driver.writer()?) + .await?; } - } - req.complete(tx, transaction, resume_path) - } - Interaction::Timed(_) => Ok(false), - Interaction::ResumeRead(req) => { - let mut resume_path = None; - - for item in self.node.resume_read(&req, &accessor) { - if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? - { - resume_path = Some(path); - break; - } + driver.complete(req).await?; } + Interaction::Invoke { + req, + ref mut driver, + } => { + let accessor = driver.accessor()?; - req.complete(tx, transaction, resume_path) - } - Interaction::ResumeSubscribe(req) => { - let mut resume_path = None; + for item in metadata.node().invoke(req, &accessor) { + let (mut tw, exchange) = driver.writer_exchange()?; - for item in self.node.resume_subscribing_read(&req, &accessor) { - if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? - { - resume_path = Some(path); - break; + CmdDataEncoder::handle(&item, &self.0, &mut tw, exchange).await?; } - } - - req.complete(tx, transaction, resume_path) - } - } - } - - #[cfg(feature = "nightly")] - pub async fn handle_async<'p>( - &mut self, - interaction: Interaction<'_>, - tx: &'p mut Packet<'_>, - transaction: &mut Transaction<'_, '_>, - ) -> Result - where - T: super::objects::asynch::AsyncHandler, - { - let accessor = Accessor::for_session(transaction.session(), self.acl_mgr); - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - match interaction { - Interaction::Read(req) => { - let mut resume_path = None; - for item in self.node.read(&req, &accessor) { - if let Some(path) = - AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? - { - resume_path = Some(path); - break; - } - } - - req.complete(tx, transaction, resume_path) - } - Interaction::Write(req) => { - for item in self.node.write(&req, &accessor) { - AttrDataEncoder::handle_write_async(item, &mut self.handler, &mut tw).await?; + driver.complete(req).await?; } - - req.complete(tx, transaction) - } - Interaction::Invoke(req) => { - for item in self.node.invoke(&req, &accessor) { - CmdDataEncoder::handle_async(item, &mut self.handler, transaction, &mut tw) - .await?; - } - - req.complete(tx, transaction) - } - Interaction::Subscribe(req) => { - let mut resume_path = None; - - for item in self.node.subscribing_read(&req, &accessor) { - if let Some(path) = - AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? - { - resume_path = Some(path); - break; + Interaction::Subscribe { + req, + ref mut driver, + } => { + let accessor = driver.accessor()?; + + 'outer: for item in metadata.node().subscribing_read(req, None, &accessor) { + while !AttrDataEncoder::handle_read(&item, &self.0, &mut driver.writer()?) + .await? + { + if !driver.send_chunk(req).await? { + break 'outer; + } + } } - } - - req.complete(tx, transaction, resume_path) - } - Interaction::Timed(_) => Ok(false), - Interaction::ResumeRead(req) => { - let mut resume_path = None; - for item in self.node.resume_read(&req, &accessor) { - if let Some(path) = - AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? - { - resume_path = Some(path); - break; - } + driver.complete(req).await?; } - - req.complete(tx, transaction, resume_path) } - Interaction::ResumeSubscribe(req) => { - let mut resume_path = None; - - for item in self.node.resume_subscribing_read(&req, &accessor) { - if let Some(path) = - AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? - { - resume_path = Some(path); - break; - } - } - - req.complete(tx, transaction, resume_path) - } - } - } -} - -pub trait DataHandler { - fn handle( - &mut self, - interaction: Interaction, - tx: &mut Packet, - transaction: &mut Transaction, - ) -> Result; -} - -impl DataHandler for &mut T -where - T: DataHandler, -{ - fn handle( - &mut self, - interaction: Interaction, - tx: &mut Packet, - transaction: &mut Transaction, - ) -> Result { - (**self).handle(interaction, tx, transaction) - } -} - -impl<'a, T> DataHandler for DataModel<'a, T> -where - T: Handler, -{ - fn handle( - &mut self, - interaction: Interaction, - tx: &mut Packet, - transaction: &mut Transaction, - ) -> Result { - DataModel::handle(self, interaction, tx, transaction) - } -} - -#[cfg(feature = "nightly")] -pub mod asynch { - use crate::{ - data_model::objects::asynch::AsyncHandler, - error::Error, - interaction_model::core::{Interaction, Transaction}, - transport::packet::Packet, - }; - - use super::DataModel; - - pub trait AsyncDataHandler { - async fn handle( - &mut self, - interaction: Interaction<'_>, - tx: &mut Packet, - transaction: &mut Transaction, - ) -> Result; - } - - impl AsyncDataHandler for &mut T - where - T: AsyncDataHandler, - { - async fn handle( - &mut self, - interaction: Interaction<'_>, - tx: &mut Packet<'_>, - transaction: &mut Transaction<'_, '_>, - ) -> Result { - (**self).handle(interaction, tx, transaction).await } - } - impl<'a, T> AsyncDataHandler for DataModel<'a, T> - where - T: AsyncHandler, - { - async fn handle( - &mut self, - interaction: Interaction<'_>, - tx: &mut Packet<'_>, - transaction: &mut Transaction<'_, '_>, - ) -> Result { - DataModel::handle_async(self, interaction, tx, transaction).await - } + Ok(()) } } diff --git a/matter/src/data_model/objects/dataver.rs b/matter/src/data_model/objects/dataver.rs index f05a3838..dcdb42d3 100644 --- a/matter/src/data_model/objects/dataver.rs +++ b/matter/src/data_model/objects/dataver.rs @@ -15,11 +15,13 @@ * limitations under the License. */ +use core::cell::Cell; + use crate::utils::rand::Rand; pub struct Dataver { - ver: u32, - changed: bool, + ver: Cell, + changed: Cell, } impl Dataver { @@ -28,25 +30,25 @@ impl Dataver { rand(&mut buf); Self { - ver: u32::from_be_bytes(buf), - changed: false, + ver: Cell::new(u32::from_be_bytes(buf)), + changed: Cell::new(false), } } pub fn get(&self) -> u32 { - self.ver + self.ver.get() } - pub fn changed(&mut self) -> u32 { - (self.ver, _) = self.ver.overflowing_add(1); - self.changed = true; + pub fn changed(&self) -> u32 { + self.ver.set(self.ver.get().overflowing_add(1).0); + self.changed.set(true); self.get() } - pub fn consume_change(&mut self, change: T) -> Option { - if self.changed { - self.changed = false; + pub fn consume_change(&self, change: T) -> Option { + if self.changed.get() { + self.changed.set(false); Some(change) } else { None diff --git a/matter/src/data_model/objects/encoder.rs b/matter/src/data_model/objects/encoder.rs index 70e0db76..73f610b2 100644 --- a/matter/src/data_model/objects/encoder.rs +++ b/matter/src/data_model/objects/encoder.rs @@ -19,12 +19,12 @@ use core::fmt::{Debug, Formatter}; use core::marker::PhantomData; use core::ops::{Deref, DerefMut}; -use crate::interaction_model::core::{IMStatusCode, Transaction}; +use crate::interaction_model::core::IMStatusCode; use crate::interaction_model::messages::ib::{ AttrPath, AttrResp, AttrStatus, CmdDataTag, CmdPath, CmdStatus, InvResp, InvRespTag, }; -use crate::interaction_model::messages::GenericPath; use crate::tlv::UtfStr; +use crate::transport::exchange::Exchange; use crate::{ error::{Error, ErrorCode}, interaction_model::messages::ib::{AttrDataTag, AttrRespTag}, @@ -32,7 +32,7 @@ use crate::{ }; use log::error; -use super::{AttrDetails, CmdDetails, Handler}; +use super::{AttrDetails, CmdDetails, DataModelHandler}; // TODO: Should this return an IMStatusCode Error? But if yes, the higher layer // may have already started encoding the 'success' headers, we might not want to manage @@ -124,106 +124,79 @@ pub struct AttrDataEncoder<'a, 'b, 'c> { } impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { - pub fn handle_read( - item: Result, + pub async fn handle_read( + item: &Result, AttrStatus>, handler: &T, - tw: &mut TLVWriter, - ) -> Result, Error> { + tw: &mut TLVWriter<'_, '_>, + ) -> Result { let status = match item { Ok(attr) => { - let encoder = AttrDataEncoder::new(&attr, tw); + let encoder = AttrDataEncoder::new(attr, tw); + + let result = { + #[cfg(not(feature = "nightly"))] + { + handler.read(attr, encoder) + } + + #[cfg(feature = "nightly")] + { + handler.read(&attr, encoder).await + } + }; - match handler.read(&attr, encoder) { + match result { Ok(()) => None, Err(e) => { if e.code() == ErrorCode::NoSpace { - return Ok(Some(attr.path().to_gp())); + return Ok(false); } else { attr.status(e.into())? } } } } - Err(status) => Some(status), + Err(status) => Some(status.clone()), }; if let Some(status) = status { AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; } - Ok(None) - } - - pub fn handle_write( - item: Result<(AttrDetails, TLVElement), AttrStatus>, - handler: &mut T, - tw: &mut TLVWriter, - ) -> Result<(), Error> { - let status = match item { - Ok((attr, data)) => match handler.write(&attr, AttrData::new(attr.dataver, &data)) { - Ok(()) => attr.status(IMStatusCode::Success)?, - Err(error) => attr.status(error.into())?, - }, - Err(status) => Some(status), - }; - - if let Some(status) = status { - status.to_tlv(tw, TagType::Anonymous)?; - } - - Ok(()) + Ok(true) } - #[cfg(feature = "nightly")] - pub async fn handle_read_async( - item: Result, AttrStatus>, + pub async fn handle_write( + item: &Result<(AttrDetails<'_>, TLVElement<'_>), AttrStatus>, handler: &T, tw: &mut TLVWriter<'_, '_>, - ) -> Result, Error> { + ) -> Result<(), Error> { let status = match item { - Ok(attr) => { - let encoder = AttrDataEncoder::new(&attr, tw); + Ok((attr, data)) => { + let result = { + #[cfg(not(feature = "nightly"))] + { + handler.write(attr, AttrData::new(attr.dataver, data)) + } - match handler.read(&attr, encoder).await { - Ok(()) => None, - Err(e) => { - if e.code() == ErrorCode::NoSpace { - return Ok(Some(attr.path().to_gp())); - } else { - attr.status(e.into())? - } + #[cfg(feature = "nightly")] + { + handler + .write(&attr, AttrData::new(attr.dataver, &data)) + .await } + }; + + match result { + Ok(()) => attr.status(IMStatusCode::Success)?, + Err(error) => attr.status(error.into())?, } } - Err(status) => Some(status), + Err(status) => Some(status.clone()), }; if let Some(status) = status { - AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; - } - - Ok(None) - } - - #[cfg(feature = "nightly")] - pub async fn handle_write_async( - item: Result<(AttrDetails<'_>, TLVElement<'_>), AttrStatus>, - handler: &mut T, - tw: &mut TLVWriter<'_, '_>, - ) -> Result<(), Error> { - let status = match item { - Ok((attr, data)) => match handler - .write(&attr, AttrData::new(attr.dataver, &data)) - .await - { - Ok(()) => attr.status(IMStatusCode::Success)?, - Err(error) => attr.status(error.into())?, - }, - Err(status) => Some(status), - }; - - if let Some(status) = status { - AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; + status.to_tlv(tw, TagType::Anonymous)?; } Ok(()) @@ -365,18 +338,30 @@ pub struct CmdDataEncoder<'a, 'b, 'c> { } impl<'a, 'b, 'c> CmdDataEncoder<'a, 'b, 'c> { - pub fn handle( - item: Result<(CmdDetails, TLVElement), CmdStatus>, - handler: &mut T, - transaction: &mut Transaction, - tw: &mut TLVWriter, + pub async fn handle( + item: &Result<(CmdDetails<'_>, TLVElement<'_>), CmdStatus>, + handler: &T, + tw: &mut TLVWriter<'_, '_>, + exchange: &Exchange<'_>, ) -> Result<(), Error> { let status = match item { Ok((cmd, data)) => { let mut tracker = CmdDataTracker::new(); - let encoder = CmdDataEncoder::new(&cmd, &mut tracker, tw); + let encoder = CmdDataEncoder::new(cmd, &mut tracker, tw); + + let result = { + #[cfg(not(feature = "nightly"))] + { + handler.invoke(exchange, cmd, data, encoder) + } + + #[cfg(feature = "nightly")] + { + handler.invoke(exchange, &cmd, &data, encoder).await + } + }; - match handler.invoke(transaction, &cmd, &data, encoder) { + match result { Ok(()) => cmd.success(&tracker), Err(error) => { error!("Error invoking command: {}", error); @@ -386,35 +371,8 @@ impl<'a, 'b, 'c> CmdDataEncoder<'a, 'b, 'c> { } Err(status) => { error!("Error invoking command: {:?}", status); - Some(status) - } - }; - - if let Some(status) = status { - InvResp::Status(status).to_tlv(tw, TagType::Anonymous)?; - } - - Ok(()) - } - - #[cfg(feature = "nightly")] - pub async fn handle_async( - item: Result<(CmdDetails<'_>, TLVElement<'_>), CmdStatus>, - handler: &mut T, - transaction: &mut Transaction<'_, '_>, - tw: &mut TLVWriter<'_, '_>, - ) -> Result<(), Error> { - let status = match item { - Ok((cmd, data)) => { - let mut tracker = CmdDataTracker::new(); - let encoder = CmdDataEncoder::new(&cmd, &mut tracker, tw); - - match handler.invoke(transaction, &cmd, &data, encoder).await { - Ok(()) => cmd.success(&tracker), - Err(error) => cmd.status(error.into()), - } + Some(status.clone()) } - Err(status) => Some(status), }; if let Some(status) = status { diff --git a/matter/src/data_model/objects/handler.rs b/matter/src/data_model/objects/handler.rs index 143cad87..03cac3fa 100644 --- a/matter/src/data_model/objects/handler.rs +++ b/matter/src/data_model/objects/handler.rs @@ -17,12 +17,25 @@ use crate::{ error::{Error, ErrorCode}, - interaction_model::core::Transaction, tlv::TLVElement, + transport::exchange::Exchange, }; use super::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails}; +#[cfg(feature = "nightly")] +pub use asynch::*; + +#[cfg(not(feature = "nightly"))] +pub trait DataModelHandler: super::Metadata + Handler {} +#[cfg(not(feature = "nightly"))] +impl DataModelHandler for T where T: super::Metadata + Handler {} + +#[cfg(feature = "nightly")] +pub trait DataModelHandler: super::asynch::AsyncMetadata + asynch::AsyncHandler {} +#[cfg(feature = "nightly")] +impl DataModelHandler for T where T: super::asynch::AsyncMetadata + asynch::AsyncHandler {} + pub trait ChangeNotifier { fn consume_change(&mut self) -> Option; } @@ -30,13 +43,13 @@ pub trait ChangeNotifier { pub trait Handler { fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error>; - fn write(&mut self, _attr: &AttrDetails, _data: AttrData) -> Result<(), Error> { + fn write(&self, _attr: &AttrDetails, _data: AttrData) -> Result<(), Error> { Err(ErrorCode::AttributeNotFound.into()) } fn invoke( - &mut self, - _transaction: &mut Transaction, + &self, + _exchange: &Exchange, _cmd: &CmdDetails, _data: &TLVElement, _encoder: CmdDataEncoder, @@ -45,6 +58,29 @@ pub trait Handler { } } +impl Handler for &T +where + T: Handler, +{ + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + (**self).read(attr, encoder) + } + + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + (**self).write(attr, data) + } + + fn invoke( + &self, + exchange: &Exchange, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + (**self).invoke(exchange, cmd, data, encoder) + } +} + impl Handler for &mut T where T: Handler, @@ -53,25 +89,52 @@ where (**self).read(attr, encoder) } - fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { (**self).write(attr, data) } fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { - (**self).invoke(transaction, cmd, data, encoder) + (**self).invoke(exchange, cmd, data, encoder) } } pub trait NonBlockingHandler: Handler {} +impl NonBlockingHandler for &T where T: NonBlockingHandler {} + impl NonBlockingHandler for &mut T where T: NonBlockingHandler {} +impl Handler for (M, H) +where + H: Handler, +{ + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + self.1.read(attr, encoder) + } + + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + self.1.write(attr, data) + } + + fn invoke( + &self, + exchange: &Exchange, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + self.1.invoke(exchange, cmd, data, encoder) + } +} + +impl NonBlockingHandler for (M, H) where H: NonBlockingHandler {} + pub struct EmptyHandler; impl EmptyHandler { @@ -140,7 +203,7 @@ where } } - fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id { self.handler.write(attr, data) } else { @@ -149,16 +212,16 @@ where } fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { if self.handler_endpoint == cmd.endpoint_id && self.handler_cluster == cmd.cluster_id { - self.handler.invoke(transaction, cmd, data, encoder) + self.handler.invoke(exchange, cmd, data, encoder) } else { - self.next.invoke(transaction, cmd, data, encoder) + self.next.invoke(exchange, cmd, data, encoder) } } } @@ -184,6 +247,35 @@ where } } +/// Wrap your `NonBlockingHandler` or `AsyncHandler` implementation in this struct +/// to get your code compilable with and without the `nightly` feature +pub struct HandlerCompat(pub T); + +impl Handler for HandlerCompat +where + T: Handler, +{ + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + self.0.read(attr, encoder) + } + + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + self.0.write(attr, data) + } + + fn invoke( + &self, + exchange: &Exchange, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + self.0.invoke(exchange, cmd, data, encoder) + } +} + +impl NonBlockingHandler for HandlerCompat where T: NonBlockingHandler {} + #[allow(unused_macros)] #[macro_export] macro_rules! handler_chain_type { @@ -203,15 +295,15 @@ macro_rules! handler_chain_type { } #[cfg(feature = "nightly")] -pub mod asynch { +mod asynch { use crate::{ data_model::objects::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails}, error::{Error, ErrorCode}, - interaction_model::core::Transaction, tlv::TLVElement, + transport::exchange::Exchange, }; - use super::{ChainedHandler, EmptyHandler, Handler, NonBlockingHandler}; + use super::{ChainedHandler, EmptyHandler, Handler, HandlerCompat, NonBlockingHandler}; pub trait AsyncHandler { async fn read<'a>( @@ -221,7 +313,7 @@ pub mod asynch { ) -> Result<(), Error>; async fn write<'a>( - &'a mut self, + &'a self, _attr: &'a AttrDetails<'_>, _data: AttrData<'a>, ) -> Result<(), Error> { @@ -229,8 +321,8 @@ pub mod asynch { } async fn invoke<'a>( - &'a mut self, - _transaction: &'a mut Transaction<'_, '_>, + &'a self, + _exchange: &'a Exchange<'_>, _cmd: &'a CmdDetails<'_>, _data: &'a TLVElement<'_>, _encoder: CmdDataEncoder<'a, '_, '_>, @@ -252,7 +344,38 @@ pub mod asynch { } async fn write<'a>( - &'a mut self, + &'a self, + attr: &'a AttrDetails<'_>, + data: AttrData<'a>, + ) -> Result<(), Error> { + (**self).write(attr, data).await + } + + async fn invoke<'a>( + &'a self, + exchange: &'a Exchange<'_>, + cmd: &'a CmdDetails<'_>, + data: &'a TLVElement<'_>, + encoder: CmdDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + (**self).invoke(exchange, cmd, data, encoder).await + } + } + + impl AsyncHandler for &T + where + T: AsyncHandler, + { + async fn read<'a>( + &'a self, + attr: &'a AttrDetails<'_>, + encoder: AttrDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + (**self).read(attr, encoder).await + } + + async fn write<'a>( + &'a self, attr: &'a AttrDetails<'_>, data: AttrData<'a>, ) -> Result<(), Error> { @@ -260,19 +383,48 @@ pub mod asynch { } async fn invoke<'a>( - &'a mut self, - transaction: &'a mut Transaction<'_, '_>, + &'a self, + exchange: &'a Exchange<'_>, cmd: &'a CmdDetails<'_>, data: &'a TLVElement<'_>, encoder: CmdDataEncoder<'a, '_, '_>, ) -> Result<(), Error> { - (**self).invoke(transaction, cmd, data, encoder).await + (**self).invoke(exchange, cmd, data, encoder).await } } - pub struct Asyncify(pub T); + impl AsyncHandler for (M, H) + where + H: AsyncHandler, + { + async fn read<'a>( + &'a self, + attr: &'a AttrDetails<'_>, + encoder: AttrDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + self.1.read(attr, encoder).await + } + + async fn write<'a>( + &'a self, + attr: &'a AttrDetails<'_>, + data: AttrData<'a>, + ) -> Result<(), Error> { + self.1.write(attr, data).await + } + + async fn invoke<'a>( + &'a self, + exchange: &'a Exchange<'_>, + cmd: &'a CmdDetails<'_>, + data: &'a TLVElement<'_>, + encoder: CmdDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + self.1.invoke(exchange, cmd, data, encoder).await + } + } - impl AsyncHandler for Asyncify + impl AsyncHandler for HandlerCompat where T: NonBlockingHandler, { @@ -285,21 +437,21 @@ pub mod asynch { } async fn write<'a>( - &'a mut self, + &'a self, attr: &'a AttrDetails<'_>, data: AttrData<'a>, ) -> Result<(), Error> { - Handler::write(&mut self.0, attr, data) + Handler::write(&self.0, attr, data) } async fn invoke<'a>( - &'a mut self, - transaction: &'a mut Transaction<'_, '_>, + &'a self, + exchange: &'a Exchange<'_>, cmd: &'a CmdDetails<'_>, data: &'a TLVElement<'_>, encoder: CmdDataEncoder<'a, '_, '_>, ) -> Result<(), Error> { - Handler::invoke(&mut self.0, transaction, cmd, data, encoder) + Handler::invoke(&self.0, exchange, cmd, data, encoder) } } @@ -332,7 +484,7 @@ pub mod asynch { } async fn write<'a>( - &'a mut self, + &'a self, attr: &'a AttrDetails<'_>, data: AttrData<'a>, ) -> Result<(), Error> { @@ -345,16 +497,16 @@ pub mod asynch { } async fn invoke<'a>( - &'a mut self, - transaction: &'a mut Transaction<'_, '_>, + &'a self, + exchange: &'a Exchange<'_>, cmd: &'a CmdDetails<'_>, data: &'a TLVElement<'_>, encoder: CmdDataEncoder<'a, '_, '_>, ) -> Result<(), Error> { if self.handler_endpoint == cmd.endpoint_id && self.handler_cluster == cmd.cluster_id { - self.handler.invoke(transaction, cmd, data, encoder).await + self.handler.invoke(exchange, cmd, data, encoder).await } else { - self.next.invoke(transaction, cmd, data, encoder).await + self.next.invoke(exchange, cmd, data, encoder).await } } } diff --git a/matter/src/data_model/objects/metadata.rs b/matter/src/data_model/objects/metadata.rs new file mode 100644 index 00000000..368ff9b6 --- /dev/null +++ b/matter/src/data_model/objects/metadata.rs @@ -0,0 +1,178 @@ +use crate::data_model::objects::Node; + +#[cfg(feature = "nightly")] +pub use asynch::*; + +use super::HandlerCompat; + +pub trait MetadataGuard { + fn node(&self) -> Node<'_>; +} + +impl MetadataGuard for &T +where + T: MetadataGuard, +{ + fn node(&self) -> Node<'_> { + (**self).node() + } +} + +impl MetadataGuard for &mut T +where + T: MetadataGuard, +{ + fn node(&self) -> Node<'_> { + (**self).node() + } +} + +pub trait Metadata { + type MetadataGuard<'a>: MetadataGuard + where + Self: 'a; + + fn lock(&self) -> Self::MetadataGuard<'_>; +} + +impl Metadata for &T +where + T: Metadata, +{ + type MetadataGuard<'a> = T::MetadataGuard<'a> where Self: 'a; + + fn lock(&self) -> Self::MetadataGuard<'_> { + (**self).lock() + } +} + +impl Metadata for &mut T +where + T: Metadata, +{ + type MetadataGuard<'a> = T::MetadataGuard<'a> where Self: 'a; + + fn lock(&self) -> Self::MetadataGuard<'_> { + (**self).lock() + } +} + +impl<'a> MetadataGuard for Node<'a> { + fn node(&self) -> Node<'_> { + Node { + id: self.id, + endpoints: self.endpoints, + } + } +} + +impl<'a> Metadata for Node<'a> { + type MetadataGuard<'g> = Node<'g> where Self: 'g; + + fn lock(&self) -> Self::MetadataGuard<'_> { + Node { + id: self.id, + endpoints: self.endpoints, + } + } +} + +impl Metadata for (M, H) +where + M: Metadata, +{ + type MetadataGuard<'a> = M::MetadataGuard<'a> + where + Self: 'a; + + fn lock(&self) -> Self::MetadataGuard<'_> { + self.0.lock() + } +} + +impl Metadata for HandlerCompat +where + T: Metadata, +{ + type MetadataGuard<'a> = T::MetadataGuard<'a> + where + Self: 'a; + + fn lock(&self) -> Self::MetadataGuard<'_> { + self.0.lock() + } +} + +#[cfg(feature = "nightly")] +pub mod asynch { + use crate::data_model::objects::{HandlerCompat, Node}; + + use super::{Metadata, MetadataGuard}; + + pub trait AsyncMetadata { + type MetadataGuard<'a>: MetadataGuard + where + Self: 'a; + + async fn lock(&self) -> Self::MetadataGuard<'_>; + } + + impl AsyncMetadata for &T + where + T: AsyncMetadata, + { + type MetadataGuard<'a> = T::MetadataGuard<'a> where Self: 'a; + + async fn lock(&self) -> Self::MetadataGuard<'_> { + (**self).lock().await + } + } + + impl AsyncMetadata for &mut T + where + T: AsyncMetadata, + { + type MetadataGuard<'a> = T::MetadataGuard<'a> where Self: 'a; + + async fn lock(&self) -> Self::MetadataGuard<'_> { + (**self).lock().await + } + } + + impl<'a> AsyncMetadata for Node<'a> { + type MetadataGuard<'g> = Node<'g> where Self: 'g; + + async fn lock(&self) -> Self::MetadataGuard<'_> { + Node { + id: self.id, + endpoints: self.endpoints, + } + } + } + + impl AsyncMetadata for (M, H) + where + M: AsyncMetadata, + { + type MetadataGuard<'a> = M::MetadataGuard<'a> + where + Self: 'a; + + async fn lock(&self) -> Self::MetadataGuard<'_> { + self.0.lock().await + } + } + + impl AsyncMetadata for HandlerCompat + where + T: Metadata, + { + type MetadataGuard<'a> = T::MetadataGuard<'a> + where + Self: 'a; + + async fn lock(&self) -> Self::MetadataGuard<'_> { + self.0.lock() + } + } +} diff --git a/matter/src/data_model/objects/mod.rs b/matter/src/data_model/objects/mod.rs index 1bd326e4..b8b05112 100644 --- a/matter/src/data_model/objects/mod.rs +++ b/matter/src/data_model/objects/mod.rs @@ -41,6 +41,9 @@ pub use handler::*; mod dataver; pub use dataver::*; +mod metadata; +pub use metadata::*; + pub type EndptId = u16; pub type ClusterId = u32; pub type AttrId = u16; diff --git a/matter/src/data_model/objects/node.rs b/matter/src/data_model/objects/node.rs index 41720b61..1ffa8967 100644 --- a/matter/src/data_model/objects/node.rs +++ b/matter/src/data_model/objects/node.rs @@ -17,9 +17,10 @@ use crate::{ acl::Accessor, + alloc, data_model::objects::Endpoint, interaction_model::{ - core::{IMStatusCode, ResumeReadReq, ResumeSubscribeReq}, + core::IMStatusCode, messages::{ ib::{AttrPath, AttrStatus, CmdStatus, DataVersionFilter}, msg::{InvReq, ReadReq, SubscribeReq, WriteReq}, @@ -27,7 +28,7 @@ use crate::{ }, }, // TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer - tlv::{TLVArray, TLVArrayIter, TLVElement}, + tlv::{TLVArray, TLVElement}, }; use core::{ fmt, @@ -57,41 +58,6 @@ where } } -pub trait Iterable { - type Item; - - type Iterator<'a>: Iterator - where - Self: 'a; - - fn iter(&self) -> Self::Iterator<'_>; -} - -impl<'a> Iterable for Option<&'a TLVArray<'a, DataVersionFilter>> { - type Item = DataVersionFilter; - - type Iterator<'i> = WildcardIter, DataVersionFilter> where Self: 'i; - - fn iter(&self) -> Self::Iterator<'_> { - if let Some(filters) = self { - WildcardIter::Wildcard(filters.iter()) - } else { - WildcardIter::None - } - } -} - -impl<'a> Iterable for &'a [DataVersionFilter] { - type Item = DataVersionFilter; - - type Iterator<'i> = core::iter::Cloned> where Self: 'i; - - fn iter(&self) -> Self::Iterator<'_> { - let slice: &[DataVersionFilter] = self; - slice.iter().cloned() - } -} - #[derive(Debug, Clone)] pub struct Node<'a> { pub id: u16, @@ -102,6 +68,7 @@ impl<'a> Node<'a> { pub fn read<'s, 'm>( &'s self, req: &'m ReadReq, + from: Option, accessor: &'m Accessor<'m>, ) -> impl Iterator> + 'm where @@ -114,30 +81,14 @@ impl<'a> Node<'a> { req.dataver_filters.as_ref(), req.fabric_filtered, accessor, - None, - ) - } - - pub fn resume_read<'s, 'm>( - &'s self, - req: &'m ResumeReadReq, - accessor: &'m Accessor<'m>, - ) -> impl Iterator> + 'm - where - 's: 'm, - { - self.read_attr_requests( - req.paths.iter().cloned(), - req.filters.as_slice(), - req.fabric_filtered, - accessor, - Some(req.resume_path.clone()), + from, ) } pub fn subscribing_read<'s, 'm>( &'s self, req: &'m SubscribeReq, + from: Option, accessor: &'m Accessor<'m>, ) -> impl Iterator> + 'm where @@ -150,31 +101,14 @@ impl<'a> Node<'a> { req.dataver_filters.as_ref(), req.fabric_filtered, accessor, - None, + from, ) } - pub fn resume_subscribing_read<'s, 'm>( - &'s self, - req: &'m ResumeSubscribeReq, - accessor: &'m Accessor<'m>, - ) -> impl Iterator> + 'm - where - 's: 'm, - { - self.read_attr_requests( - req.paths.iter().cloned(), - req.filters.as_slice(), - req.fabric_filtered, - accessor, - Some(req.resume_path.clone().unwrap()), - ) - } - - fn read_attr_requests<'s, 'm, P, D>( + fn read_attr_requests<'s, 'm, P>( &'s self, attr_requests: P, - dataver_filters: D, + dataver_filters: Option<&'m TLVArray>, fabric_filtered: bool, accessor: &'m Accessor<'m>, from: Option, @@ -182,11 +116,9 @@ impl<'a> Node<'a> { where 's: 'm, P: Iterator + 'm, - D: Iterable + Clone + 'm, { - attr_requests.flat_map(move |path| { + alloc!(attr_requests.flat_map(move |path| { if path.to_gp().is_wildcard() { - let dataver_filters = dataver_filters.clone(); let from = from.clone(); let iter = self @@ -204,10 +136,14 @@ impl<'a> Node<'a> { .is_ok() }) .map(move |(ep, cl, attr)| { - let dataver = dataver_filters.iter().find_map(|filter| { - (filter.path.endpoint == ep.id && filter.path.cluster == cl.id) - .then_some(filter.data_ver) - }); + let dataver = if let Some(dataver_filters) = dataver_filters { + dataver_filters.iter().find_map(|filter| { + (filter.path.endpoint == ep.id && filter.path.cluster == cl.id) + .then_some(filter.data_ver) + }) + } else { + None + }; Ok(AttrDetails { node: self, @@ -230,10 +166,14 @@ impl<'a> Node<'a> { let result = match self.check_attribute(accessor, ep, cl, attr, false) { Ok(()) => { - let dataver = dataver_filters.iter().find_map(|filter| { - (filter.path.endpoint == ep && filter.path.cluster == cl) - .then_some(filter.data_ver) - }); + let dataver = if let Some(dataver_filters) = dataver_filters { + dataver_filters.iter().find_map(|filter| { + (filter.path.endpoint == ep && filter.path.cluster == cl) + .then_some(filter.data_ver) + }) + } else { + None + }; Ok(AttrDetails { node: self, @@ -252,7 +192,7 @@ impl<'a> Node<'a> { WildcardIter::Single(once(result)) } - }) + })) } pub fn write<'m>( @@ -260,7 +200,7 @@ impl<'a> Node<'a> { req: &'m WriteReq, accessor: &'m Accessor<'m>, ) -> impl Iterator), AttrStatus>> + 'm { - req.write_requests.iter().flat_map(move |attr_data| { + alloc!(req.write_requests.iter().flat_map(move |attr_data| { if attr_data.path.cluster.is_none() { WildcardIter::Single(once(Err(AttrStatus::new( &attr_data.path.to_gp(), @@ -332,7 +272,7 @@ impl<'a> Node<'a> { WildcardIter::Single(once(result)) } - }) + })) } pub fn invoke<'m>( @@ -340,7 +280,8 @@ impl<'a> Node<'a> { req: &'m InvReq, accessor: &'m Accessor<'m>, ) -> impl Iterator), CmdStatus>> + 'm { - req.inv_requests + alloc!(req + .inv_requests .iter() .flat_map(|inv_requests| inv_requests.iter()) .flat_map(move |cmd_data| { @@ -393,7 +334,7 @@ impl<'a> Node<'a> { WildcardIter::Single(once(result)) } - }) + })) } fn matches(path: Option<&GenericPath>, ep: EndptId, cl: ClusterId, leaf: u32) -> bool { diff --git a/matter/src/data_model/root_endpoint.rs b/matter/src/data_model/root_endpoint.rs index 21691cdf..69df3bd8 100644 --- a/matter/src/data_model/root_endpoint.rs +++ b/matter/src/data_model/root_endpoint.rs @@ -46,7 +46,7 @@ pub const CLUSTERS: [Cluster<'static>; 7] = [ access_control::CLUSTER, ]; -pub fn endpoint(id: EndptId) -> Endpoint<'static> { +pub const fn endpoint(id: EndptId) -> Endpoint<'static> { Endpoint { id, device_type: super::device_types::DEV_TYPE_ROOT_NODE, diff --git a/matter/src/data_model/sdm/admin_commissioning.rs b/matter/src/data_model/sdm/admin_commissioning.rs index 15c803f4..3cce0f73 100644 --- a/matter/src/data_model/sdm/admin_commissioning.rs +++ b/matter/src/data_model/sdm/admin_commissioning.rs @@ -19,11 +19,11 @@ use core::cell::RefCell; use core::convert::TryInto; use crate::data_model::objects::*; -use crate::interaction_model::core::Transaction; use crate::mdns::Mdns; use crate::secure_channel::pake::PaseMgr; use crate::secure_channel::spake2p::VerifierData; use crate::tlv::{FromTLV, Nullable, OctetStr, TLVElement}; +use crate::transport::exchange::Exchange; use crate::utils::rand::Rand; use crate::{attribute_enum, cmd_enter}; use crate::{command_enum, error::*}; @@ -84,8 +84,8 @@ pub const CLUSTER: Cluster<'static> = Cluster { ], commands: &[ Commands::OpenCommWindow as _, - Commands::OpenBasicCommWindow as _, - Commands::RevokeComm as _, + // Commands::OpenBasicCommWindow as _, + // Commands::RevokeComm as _, ], }; @@ -133,7 +133,7 @@ impl<'a> AdminCommCluster<'a> { } pub fn invoke( - &mut self, + &self, cmd: &CmdDetails, data: &TLVElement, _encoder: CmdDataEncoder, @@ -148,7 +148,7 @@ impl<'a> AdminCommCluster<'a> { Ok(()) } - fn handle_command_opencomm_win(&mut self, data: &TLVElement) -> Result<(), Error> { + fn handle_command_opencomm_win(&self, data: &TLVElement) -> Result<(), Error> { cmd_enter!("Open Commissioning Window"); let req = OpenCommWindowReq::from_tlv(data)?; let verifier = VerifierData::new(req.verifier.0, req.iterations, req.salt.0); @@ -166,8 +166,8 @@ impl<'a> Handler for AdminCommCluster<'a> { } fn invoke( - &mut self, - _transaction: &mut Transaction, + &self, + _exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, diff --git a/matter/src/data_model/sdm/general_commissioning.rs b/matter/src/data_model/sdm/general_commissioning.rs index 78c3bef3..0784baec 100644 --- a/matter/src/data_model/sdm/general_commissioning.rs +++ b/matter/src/data_model/sdm/general_commissioning.rs @@ -20,8 +20,8 @@ use core::convert::TryInto; use crate::data_model::objects::*; use crate::data_model::sdm::failsafe::FailSafe; -use crate::interaction_model::core::Transaction; use crate::tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; +use crate::transport::exchange::Exchange; use crate::utils::rand::Rand; use crate::{attribute_enum, cmd_enter}; use crate::{command_enum, error::*}; @@ -171,19 +171,19 @@ impl<'a> GenCommCluster<'a> { } pub fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { match cmd.cmd_id.try_into()? { - Commands::ArmFailsafe => self.handle_command_armfailsafe(transaction, data, encoder)?, + Commands::ArmFailsafe => self.handle_command_armfailsafe(exchange, data, encoder)?, Commands::SetRegulatoryConfig => { - self.handle_command_setregulatoryconfig(transaction, data, encoder)? + self.handle_command_setregulatoryconfig(exchange, data, encoder)? } Commands::CommissioningComplete => { - self.handle_command_commissioningcomplete(transaction, encoder)?; + self.handle_command_commissioningcomplete(exchange, encoder)?; } } @@ -193,8 +193,8 @@ impl<'a> GenCommCluster<'a> { } fn handle_command_armfailsafe( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { @@ -207,7 +207,7 @@ impl<'a> GenCommCluster<'a> { .borrow_mut() .arm( p.expiry_len, - transaction.session().get_session_mode().clone(), + exchange.with_session(|sess| Ok(sess.get_session_mode().clone()))?, ) .is_err() { @@ -225,13 +225,12 @@ impl<'a> GenCommCluster<'a> { .with_command(RespCommands::ArmFailsafeResp as _)? .set(cmd_data)?; - transaction.complete(); Ok(()) } fn handle_command_setregulatoryconfig( - &mut self, - transaction: &mut Transaction, + &self, + _exchange: &Exchange, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { @@ -252,20 +251,22 @@ impl<'a> GenCommCluster<'a> { .with_command(RespCommands::SetRegulatoryConfigResp as _)? .set(cmd_data)?; - transaction.complete(); Ok(()) } fn handle_command_commissioningcomplete( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, encoder: CmdDataEncoder, ) -> Result<(), Error> { cmd_enter!("Commissioning Complete"); let mut status: u8 = CommissioningError::Ok as u8; // Has to be a Case Session - if transaction.session().get_local_fabric_idx().is_none() { + if exchange + .with_session(|sess| Ok(sess.get_local_fabric_idx()))? + .is_none() + { status = CommissioningError::ErrInvalidAuth as u8; } @@ -274,7 +275,7 @@ impl<'a> GenCommCluster<'a> { if self .failsafe .borrow_mut() - .disarm(transaction.session().get_session_mode().clone()) + .disarm(exchange.with_session(|sess| Ok(sess.get_session_mode().clone()))?) .is_err() { status = CommissioningError::ErrInvalidAuth as u8; @@ -289,7 +290,6 @@ impl<'a> GenCommCluster<'a> { .with_command(RespCommands::CommissioningCompleteResp as _)? .set(cmd_data)?; - transaction.complete(); Ok(()) } } @@ -300,13 +300,13 @@ impl<'a> Handler for GenCommCluster<'a> { } fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { - GenCommCluster::invoke(self, transaction, cmd, data, encoder) + GenCommCluster::invoke(self, exchange, cmd, data, encoder) } } diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index 7fb1e37b..8b66cb4c 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -24,9 +24,9 @@ use crate::crypto::{self, KeyPair}; use crate::data_model::objects::*; use crate::data_model::sdm::dev_att; use crate::fabric::{Fabric, FabricMgr, MAX_SUPPORTED_FABRICS}; -use crate::interaction_model::core::Transaction; use crate::mdns::Mdns; use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; +use crate::transport::exchange::Exchange; use crate::transport::session::SessionMode; use crate::utils::epoch::Epoch; use crate::utils::rand::Rand; @@ -289,26 +289,26 @@ impl<'a> NocCluster<'a> { } pub fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { match cmd.cmd_id.try_into()? { - Commands::AddNOC => self.handle_command_addnoc(transaction, data, encoder)?, - Commands::CSRReq => self.handle_command_csrrequest(transaction, data, encoder)?, + Commands::AddNOC => self.handle_command_addnoc(exchange, data, encoder)?, + Commands::CSRReq => self.handle_command_csrrequest(exchange, data, encoder)?, Commands::AddTrustedRootCert => { - self.handle_command_addtrustedrootcert(transaction, data)? + self.handle_command_addtrustedrootcert(exchange, data)? } - Commands::AttReq => self.handle_command_attrequest(transaction, data, encoder)?, + Commands::AttReq => self.handle_command_attrequest(exchange, data, encoder)?, Commands::CertChainReq => { - self.handle_command_certchainrequest(transaction, data, encoder)? + self.handle_command_certchainrequest(exchange, data, encoder)? } Commands::UpdateFabricLabel => { - self.handle_command_updatefablabel(transaction, data, encoder)?; + self.handle_command_updatefablabel(exchange, data, encoder)?; } - Commands::RemoveFabric => self.handle_command_rmfabric(transaction, data, encoder)?, + Commands::RemoveFabric => self.handle_command_rmfabric(exchange, data, encoder)?, } self.data_ver.changed(); @@ -323,13 +323,12 @@ impl<'a> NocCluster<'a> { } fn _handle_command_addnoc( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, data: &TLVElement, ) -> Result { - let noc_data = transaction - .session_mut() - .take_noc_data() + let noc_data = exchange + .with_session_mut(|sess| Ok(sess.take_noc_data()))? .ok_or(NocStatus::MissingCsr)?; if !self @@ -411,42 +410,42 @@ impl<'a> NocCluster<'a> { } fn handle_command_updatefablabel( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { cmd_enter!("Update Fabric Label"); let req = UpdateFabricLabelReq::from_tlv(data).map_err(Error::map_invalid_data_type)?; - let (result, fab_idx) = - if let SessionMode::Case(c) = transaction.session().get_session_mode() { - if self - .fabric_mgr - .borrow_mut() - .set_label( - c.fab_idx, - req.label.as_str().map_err(Error::map_invalid_data_type)?, - ) - .is_err() - { - (NocStatus::LabelConflict, c.fab_idx) - } else { - (NocStatus::Ok, c.fab_idx) - } + let (result, fab_idx) = if let SessionMode::Case(c) = + exchange.with_session(|sess| Ok(sess.get_session_mode().clone()))? + { + if self + .fabric_mgr + .borrow_mut() + .set_label( + c.fab_idx, + req.label.as_str().map_err(Error::map_invalid_data_type)?, + ) + .is_err() + { + (NocStatus::LabelConflict, c.fab_idx) } else { - // Update Fabric Label not allowed - (NocStatus::InvalidFabricIndex, 0) - }; + (NocStatus::Ok, c.fab_idx) + } + } else { + // Update Fabric Label not allowed + (NocStatus::InvalidFabricIndex, 0) + }; Self::create_nocresponse(encoder, result, fab_idx, "")?; - transaction.complete(); Ok(()) } fn handle_command_rmfabric( - &mut self, - transaction: &mut Transaction, + &self, + _exchange: &Exchange, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { @@ -459,7 +458,7 @@ impl<'a> NocCluster<'a> { .is_ok() { let _ = self.acl_mgr.borrow_mut().delete_for_fabric(req.fab_idx); - transaction.terminate(); + // TODO: transaction.terminate(); Ok(()) } else { Self::create_nocresponse(encoder, NocStatus::InvalidFabricIndex, req.fab_idx, "") @@ -467,28 +466,27 @@ impl<'a> NocCluster<'a> { } fn handle_command_addnoc( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { cmd_enter!("AddNOC"); - let (status, fab_idx) = match self._handle_command_addnoc(transaction, data) { + let (status, fab_idx) = match self._handle_command_addnoc(exchange, data) { Ok(fab_idx) => (NocStatus::Ok, fab_idx), Err(NocError::Status(status)) => (status, 0), Err(NocError::Error(error)) => Err(error)?, }; Self::create_nocresponse(encoder, status, fab_idx, "")?; - transaction.complete(); Ok(()) } fn handle_command_attrequest( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { @@ -498,7 +496,10 @@ impl<'a> NocCluster<'a> { info!("Received Attestation Nonce:{:?}", req.str); let mut attest_challenge = [0u8; crypto::SYMM_KEY_LEN_BYTES]; - attest_challenge.copy_from_slice(transaction.session().get_att_challenge()); + exchange.with_session(|sess| { + attest_challenge.copy_from_slice(sess.get_att_challenge()); + Ok(()) + })?; let mut writer = encoder.with_command(RespCommands::AttReqResp as _)?; @@ -522,13 +523,12 @@ impl<'a> NocCluster<'a> { writer.complete()?; - transaction.complete(); Ok(()) } fn handle_command_certchainrequest( - &mut self, - transaction: &mut Transaction, + &self, + _exchange: &Exchange, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { @@ -549,13 +549,12 @@ impl<'a> NocCluster<'a> { .with_command(RespCommands::CertChainResp as _)? .set(cmd_data)?; - transaction.complete(); Ok(()) } fn handle_command_csrrequest( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { @@ -570,7 +569,10 @@ impl<'a> NocCluster<'a> { let noc_keypair = KeyPair::new(self.rand)?; let mut attest_challenge = [0u8; crypto::SYMM_KEY_LEN_BYTES]; - attest_challenge.copy_from_slice(transaction.session().get_att_challenge()); + exchange.with_session(|sess| { + attest_challenge.copy_from_slice(sess.get_att_challenge()); + Ok(()) + })?; let mut writer = encoder.with_command(RespCommands::CSRResp as _)?; @@ -591,15 +593,17 @@ impl<'a> NocCluster<'a> { let noc_data = NocData::new(noc_keypair); // Store this in the session data instead of cluster data, so it gets cleared // if the session goes away for some reason - transaction.session_mut().set_noc_data(noc_data); + exchange.with_session_mut(|sess| { + sess.set_noc_data(noc_data); + Ok(()) + })?; - transaction.complete(); Ok(()) } fn handle_command_addtrustedrootcert( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, data: &TLVElement, ) -> Result<(), Error> { cmd_enter!("AddTrustedRootCert"); @@ -608,25 +612,26 @@ impl<'a> NocCluster<'a> { } // This may happen on CASE or PASE. For PASE, the existence of NOC Data is necessary - match transaction.session().get_session_mode() { + match exchange.with_session(|sess| Ok(sess.get_session_mode().clone()))? { SessionMode::Case(_) => error!("CASE: AddTrustedRootCert handling pending"), // For a CASE Session, we just return success for now, SessionMode::Pase => { - let noc_data = transaction - .session_mut() - .get_noc_data() - .ok_or(ErrorCode::NoSession)?; + exchange.with_session_mut(|sess| { + let noc_data = sess.get_noc_data().ok_or(ErrorCode::NoSession)?; + + let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; + info!("Received Trusted Cert:{:x?}", req.str); + + noc_data.root_ca = heapless::Vec::from_slice(req.str.0) + .map_err(|_| ErrorCode::BufferTooSmall)?; - let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; - info!("Received Trusted Cert:{:x?}", req.str); + Ok(()) + })?; - noc_data.root_ca = - heapless::Vec::from_slice(req.str.0).map_err(|_| ErrorCode::BufferTooSmall)?; // TODO } _ => (), } - transaction.complete(); Ok(()) } } @@ -637,13 +642,13 @@ impl<'a> Handler for NocCluster<'a> { } fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { - NocCluster::invoke(self, transaction, cmd, data, encoder) + NocCluster::invoke(self, exchange, cmd, data, encoder) } } diff --git a/matter/src/data_model/system_model/access_control.rs b/matter/src/data_model/system_model/access_control.rs index 17c88e33..8301b468 100644 --- a/matter/src/data_model/system_model/access_control.rs +++ b/matter/src/data_model/system_model/access_control.rs @@ -132,7 +132,7 @@ impl<'a> AccessControlCluster<'a> { } } - pub fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + pub fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { match attr.attr_id.try_into()? { Attributes::Acl(_) => { attr_list_write(attr, data.with_dataver(self.data_ver.get())?, |op, data| { @@ -151,7 +151,7 @@ impl<'a> AccessControlCluster<'a> { /// This takes care of 4 things, add item, edit item, delete item, delete list. /// Care about fabric-scoped behaviour is taken fn write_acl_attr( - &mut self, + &self, op: &ListOperation, data: &TLVElement, fab_idx: u8, @@ -185,7 +185,7 @@ impl<'a> Handler for AccessControlCluster<'a> { AccessControlCluster::read(self, attr, encoder) } - fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { AccessControlCluster::write(self, attr, data) } } @@ -220,7 +220,7 @@ mod tests { let mut tw = TLVWriter::new(&mut writebuf); let acl_mgr = RefCell::new(AclMgr::new()); - let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand); + let acl = AccessControlCluster::new(&acl_mgr, dummy_rand); let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); @@ -258,7 +258,7 @@ mod tests { for i in &verifier { acl_mgr.borrow_mut().add(i.clone()).unwrap(); } - let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand); + let acl = AccessControlCluster::new(&acl_mgr, dummy_rand); let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); @@ -295,7 +295,7 @@ mod tests { for i in &input { acl_mgr.borrow_mut().add(i.clone()).unwrap(); } - let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand); + let acl = AccessControlCluster::new(&acl_mgr, dummy_rand); // data is don't-care actually let data = TLVElement::new(TagType::Anonymous, ElementType::True); diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index cc763a84..4ce35837 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -15,36 +15,28 @@ * limitations under the License. */ -use core::sync::atomic::{AtomicU32, Ordering}; use core::time::Duration; use crate::{ - data_model::core::DataHandler, + acl::Accessor, error::*, - tlv::{get_root_node_struct, print_tlv_list, FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, - transport::{ - exchange::{Exchange, ExchangeCtx}, - packet::Packet, - proto_ctx::ProtoCtx, - session::Session, - }, + tlv::{get_root_node_struct, FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, + transport::{exchange::Exchange, packet::Packet}, + utils::epoch::Epoch, }; -use log::{error, info}; -use num; +use log::error; +use num::{self, FromPrimitive}; use num_derive::FromPrimitive; -use owo_colors::OwoColorize; -use super::messages::{ - ib::{AttrPath, DataVersionFilter}, - msg::{self, InvReq, ReadReq, StatusResp, SubscribeReq, SubscribeResp, TimedReq, WriteReq}, - GenericPath, +use super::messages::msg::{ + self, InvReq, ReadReq, StatusResp, SubscribeReq, SubscribeResp, TimedReq, WriteReq, }; #[macro_export] macro_rules! cmd_enter { ($e:expr) => {{ use owo_colors::OwoColorize; - info! {"{} {}", "Handling Command".cyan(), $e.cyan()} + info! {"{} {}", "Handling command".cyan(), $e.cyan()} }}; } @@ -104,7 +96,7 @@ impl From for IMStatusCode { impl FromTLV<'_> for IMStatusCode { fn from_tlv(t: &TLVElement) -> Result { - num::FromPrimitive::from_u16(t.u16()?).ok_or_else(|| ErrorCode::Invalid.into()) + FromPrimitive::from_u16(t.u16()?).ok_or_else(|| ErrorCode::Invalid.into()) } } @@ -114,7 +106,7 @@ impl ToTLV for IMStatusCode { } } -#[derive(FromPrimitive, Debug, Copy, Clone, PartialEq)] +#[derive(FromPrimitive, Debug, Copy, Clone, Eq, PartialEq)] pub enum OpCode { Reserved = 0, StatusResponse = 1, @@ -129,208 +121,16 @@ pub enum OpCode { TimedRequest = 10, } -#[derive(PartialEq)] -pub enum TransactionState { - Ongoing, - Complete, - Terminate, -} -pub struct Transaction<'a, 'b> { - state: TransactionState, - ctx: &'a mut ExchangeCtx<'b>, -} - -impl<'a, 'b> Transaction<'a, 'b> { - pub fn new(ctx: &'a mut ExchangeCtx<'b>) -> Self { - Self { - state: TransactionState::Ongoing, - ctx, - } - } - - pub fn exch(&self) -> &Exchange { - self.ctx.exch - } - - pub fn exch_mut(&mut self) -> &mut Exchange { - self.ctx.exch - } - - pub fn session(&self) -> &Session { - self.ctx.sess.session() - } - - pub fn session_mut(&mut self) -> &mut Session { - self.ctx.sess.session_mut() - } - - /// Terminates the transaction, no communication (even ACKs) happens hence forth - pub fn terminate(&mut self) { - self.state = TransactionState::Terminate - } - - pub fn is_terminate(&self) -> bool { - self.state == TransactionState::Terminate - } - /// Marks the transaction as completed from the application's perspective - pub fn complete(&mut self) { - self.state = TransactionState::Complete - } - - pub fn is_complete(&self) -> bool { - self.state == TransactionState::Complete - } - - pub fn set_timeout(&mut self, timeout: u64) { - let now = (self.ctx.epoch)(); - - self.ctx - .exch - .set_data_time(now.checked_add(Duration::from_millis(timeout))); - } - - pub fn get_timeout(&mut self) -> Option { - self.ctx.exch.get_data_time() - } - - pub fn has_timed_out(&self) -> bool { - if let Some(timeout) = self.ctx.exch.get_data_time() { - if (self.ctx.epoch)() > timeout { - return true; - } - } - false - } -} - /* Interaction Model ID as per the Matter Spec */ pub const PROTO_ID_INTERACTION_MODEL: u16 = 0x01; -const MAX_RESUME_PATHS: usize = 32; -const MAX_RESUME_DATAVER_FILTERS: usize = 32; - // This is the amount of space we reserve for other things to be attached towards // the end of long reads. const LONG_READS_TLV_RESERVE_SIZE: usize = 24; -// TODO: For now... -static SUBS_ID: AtomicU32 = AtomicU32::new(1); - -pub enum Interaction<'a> { - Read(ReadReq<'a>), - Write(WriteReq<'a>), - Invoke(InvReq<'a>), - Subscribe(SubscribeReq<'a>), - Timed(TimedReq), - ResumeRead(ResumeReadReq), - ResumeSubscribe(ResumeSubscribeReq), -} - -impl<'a> Interaction<'a> { - fn new(rx: &'a Packet, transaction: &mut Transaction) -> Result, Error> { - let opcode: OpCode = rx.get_proto_opcode()?; - - let rx_data = rx.as_slice(); - - info!("{} {:?}", "Received command".cyan(), opcode); - print_tlv_list(rx_data); - - match opcode { - OpCode::ReadRequest => Ok(Some(Self::Read(ReadReq::from_tlv(&get_root_node_struct( - rx_data, - )?)?))), - OpCode::WriteRequest => Ok(Some(Self::Write(WriteReq::from_tlv( - &get_root_node_struct(rx_data)?, - )?))), - OpCode::InvokeRequest => Ok(Some(Self::Invoke(InvReq::from_tlv( - &get_root_node_struct(rx_data)?, - )?))), - OpCode::SubscribeRequest => Ok(Some(Self::Subscribe(SubscribeReq::from_tlv( - &get_root_node_struct(rx_data)?, - )?))), - OpCode::StatusResponse => { - let resp = StatusResp::from_tlv(&get_root_node_struct(rx_data)?)?; - - if resp.status == IMStatusCode::Success { - if let Some(req) = transaction.exch_mut().take_suspended_read_req() { - Ok(Some(Self::ResumeRead(req))) - } else if let Some(req) = transaction.exch_mut().take_suspended_subscribe_req() - { - Ok(Some(Self::ResumeSubscribe(req))) - } else { - Ok(None) - } - } else { - Ok(None) - } - } - OpCode::TimedRequest => Ok(Some(Self::Timed(TimedReq::from_tlv( - &get_root_node_struct(rx_data)?, - )?))), - _ => { - error!("Opcode not handled: {:?}", opcode); - Err(ErrorCode::InvalidOpcode.into()) - } - } - } - - pub fn initiate( - rx: &'a Packet, - tx: &mut Packet, - transaction: &mut Transaction, - ) -> Result, Error> { - if let Some(interaction) = Self::new(rx, transaction)? { - tx.reset(); - - let initiated = match &interaction { - Interaction::Read(req) => req.initiate(tx, transaction)?, - Interaction::Write(req) => req.initiate(tx, transaction)?, - Interaction::Invoke(req) => req.initiate(tx, transaction)?, - Interaction::Subscribe(req) => req.initiate(tx, transaction)?, - Interaction::Timed(req) => { - req.process(tx, transaction)?; - false - } - Interaction::ResumeRead(req) => req.initiate(tx, transaction)?, - Interaction::ResumeSubscribe(req) => req.initiate(tx, transaction)?, - }; - - Ok(initiated.then_some(interaction)) - } else { - Ok(None) - } - } - - fn create_status_response(tx: &mut Packet, status: IMStatusCode) -> Result<(), Error> { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); - tx.set_proto_opcode(OpCode::StatusResponse as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - let status = StatusResp { status }; - status.to_tlv(&mut tw, TagType::Anonymous) - } -} - impl<'a> ReadReq<'a> { - fn suspend(self, resume_path: GenericPath) -> ResumeReadReq { - ResumeReadReq { - paths: self - .attr_requests - .iter() - .flat_map(|attr_requests| attr_requests.iter()) - .collect(), - filters: self - .dataver_filters - .iter() - .flat_map(|filters| filters.iter()) - .collect(), - fabric_filtered: self.fabric_filtered, - resume_path, - } - } - - fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { + pub fn tx_start<'r, 'p>(&self, tx: &'r mut Packet<'p>) -> Result, Error> { + tx.reset(); tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::ReportData as u8); @@ -342,47 +142,37 @@ impl<'a> ReadReq<'a> { tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; } - Ok(true) + Ok(tw) } - pub fn complete( - self, - tx: &mut Packet, - transaction: &mut Transaction, - resume_path: Option, - ) -> Result { + pub fn tx_finish_chunk(&self, tx: &mut Packet) -> Result<(), Error> { + self.complete(tx, true) + } + + pub fn tx_finish(&self, tx: &mut Packet) -> Result<(), Error> { + self.complete(tx, false) + } + + fn complete(&self, tx: &mut Packet<'_>, more_chunks: bool) -> Result<(), Error> { let mut tw = Self::restore_long_read_space(tx)?; if self.attr_requests.is_some() { tw.end_container()?; } - let more_chunks = if let Some(resume_path) = resume_path { + if more_chunks { tw.bool( TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), true, )?; - - transaction - .exch_mut() - .set_suspended_read_req(self.suspend(resume_path)); - true - } else { - false - }; + } tw.bool( TagType::Context(msg::ReportDataTag::SupressResponse as u8), !more_chunks, )?; - tw.end_container()?; - - if !more_chunks { - transaction.complete(); - } - - Ok(true) + tw.end_container() } fn reserve_long_read_space<'p, 'b>(tx: &'p mut Packet<'b>) -> Result, Error> { @@ -401,14 +191,18 @@ impl<'a> ReadReq<'a> { } impl<'a> WriteReq<'a> { - fn initiate(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result { - if transaction.has_timed_out() { - Interaction::create_status_response(tx, IMStatusCode::Timeout)?; - - transaction.complete(); + pub fn tx_start<'r, 'p>( + &self, + tx: &'r mut Packet<'p>, + epoch: Epoch, + timeout: Option, + ) -> Result>, Error> { + if has_timed_out(epoch, timeout) { + Interaction::status_response(tx, IMStatusCode::Timeout)?; - Ok(false) + Ok(None) } else { + tx.reset(); tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::WriteResponse as u8); @@ -417,47 +211,40 @@ impl<'a> WriteReq<'a> { tw.start_struct(TagType::Anonymous)?; tw.start_array(TagType::Context(msg::WriteRespTag::WriteResponses as u8))?; - Ok(true) + Ok(Some(tw)) } } - pub fn complete(self, tx: &mut Packet, transaction: &mut Transaction) -> Result { - let suppress = self.supress_response.unwrap_or_default(); - + pub fn tx_finish(&self, tx: &mut Packet<'_>) -> Result<(), Error> { let mut tw = TLVWriter::new(tx.get_writebuf()?); tw.end_container()?; - tw.end_container()?; - - transaction.complete(); - - Ok(if suppress { - error!("Supress response is set, is this the expected handling?"); - false - } else { - true - }) + tw.end_container() } } impl<'a> InvReq<'a> { - fn initiate(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result { - if transaction.has_timed_out() { - Interaction::create_status_response(tx, IMStatusCode::Timeout)?; - - transaction.complete(); + pub fn tx_start<'r, 'p>( + &self, + tx: &'r mut Packet<'p>, + epoch: Epoch, + timeout: Option, + ) -> Result>, Error> { + if has_timed_out(epoch, timeout) { + Interaction::status_response(tx, IMStatusCode::Timeout)?; - Ok(false) + Ok(None) } else { - let timed_tx = transaction.get_timeout().map(|_| true); + let timed_tx = timeout.map(|_| true); let timed_request = self.timed_request.filter(|a| *a); // Either both should be None, or both should be Some(true) if timed_tx != timed_request { - Interaction::create_status_response(tx, IMStatusCode::TimedRequestMisMatch)?; + Interaction::status_response(tx, IMStatusCode::TimedRequestMisMatch)?; - Ok(false) + Ok(None) } else { + tx.reset(); tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::InvokeResponse as u8); @@ -475,77 +262,45 @@ impl<'a> InvReq<'a> { tw.start_array(TagType::Context(msg::InvRespTag::InvokeResponses as u8))?; } - Ok(true) + Ok(Some(tw)) } } } - pub fn complete(self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { - let suppress = self.suppress_response.unwrap_or_default(); - + pub fn tx_finish(&self, tx: &mut Packet<'_>) -> Result<(), Error> { let mut tw = TLVWriter::new(tx.get_writebuf()?); if self.inv_requests.is_some() { tw.end_container()?; } - tw.end_container()?; - - Ok(if suppress { - error!("Supress response is set, is this the expected handling?"); - false - } else { - true - }) + tw.end_container() } } impl TimedReq { - pub fn process(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result<(), Error> { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); - tx.set_proto_opcode(OpCode::StatusResponse as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - transaction.set_timeout(self.timeout.into()); - - let status = StatusResp { - status: IMStatusCode::Success, - }; + pub fn timeout(&self, epoch: Epoch) -> Duration { + epoch() + .checked_add(Duration::from_millis(self.timeout as _)) + .unwrap() + } - status.to_tlv(&mut tw, TagType::Anonymous)?; + pub fn tx_process(self, tx: &mut Packet<'_>, epoch: Epoch) -> Result { + Interaction::status_response(tx, IMStatusCode::Success)?; - Ok(()) + Ok(epoch() + .checked_add(Duration::from_millis(self.timeout as _)) + .unwrap()) } } impl<'a> SubscribeReq<'a> { - fn suspend( + pub fn tx_start<'r, 'p>( &self, - resume_path: Option, + tx: &'r mut Packet<'p>, subscription_id: u32, - ) -> ResumeSubscribeReq { - ResumeSubscribeReq { - subscription_id, - paths: self - .attr_requests - .iter() - .flat_map(|attr_requests| attr_requests.iter()) - .collect(), - filters: self - .dataver_filters - .iter() - .flat_map(|filters| filters.iter()) - .collect(), - fabric_filtered: self.fabric_filtered, - resume_path, - keep_subs: self.keep_subs, - min_int_floor: self.min_int_floor, - max_int_ceil: self.max_int_ceil, - } - } - - fn initiate(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result { + ) -> Result, Error> { + tx.reset(); tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::ReportData as u8); @@ -553,9 +308,6 @@ impl<'a> SubscribeReq<'a> { tw.start_struct(TagType::Anonymous)?; - let subscription_id = SUBS_ID.fetch_add(1, Ordering::SeqCst); - transaction.exch_mut().set_subscription_id(subscription_id); - tw.u32( TagType::Context(msg::ReportDataTag::SubscriptionId as u8), subscription_id, @@ -565,282 +317,417 @@ impl<'a> SubscribeReq<'a> { tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; } - Ok(true) + Ok(tw) } - pub fn complete( - self, - tx: &mut Packet, - transaction: &mut Transaction, - resume_path: Option, - ) -> Result { + pub fn tx_finish_chunk(&self, tx: &mut Packet<'_>, more_chunks: bool) -> Result<(), Error> { let mut tw = ReadReq::restore_long_read_space(tx)?; if self.attr_requests.is_some() { tw.end_container()?; } - if resume_path.is_some() { + if more_chunks { tw.bool( TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), true, )?; } - let subscription_id = transaction.exch_mut().take_subscription_id().unwrap(); - - transaction - .exch_mut() - .set_suspended_subscribe_req(self.suspend(resume_path, subscription_id)); - tw.bool( TagType::Context(msg::ReportDataTag::SupressResponse as u8), false, )?; - tw.end_container()?; - - Ok(true) + tw.end_container() } -} -#[derive(Debug)] -pub struct ResumeReadReq { - pub paths: heapless::Vec, - pub filters: heapless::Vec, - pub fabric_filtered: bool, - pub resume_path: GenericPath, -} - -impl ResumeReadReq { - fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { + pub fn tx_process_final(&self, tx: &mut Packet, subscription_id: u32) -> Result<(), Error> { + tx.reset(); tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); - tx.set_proto_opcode(OpCode::ReportData as u8); + tx.set_proto_opcode(OpCode::SubscribeResponse as u8); - let mut tw = ReadReq::reserve_long_read_space(tx)?; + let mut tw = TLVWriter::new(tx.get_writebuf()?); - tw.start_struct(TagType::Anonymous)?; + let resp = SubscribeResp::new(subscription_id, 40); + resp.to_tlv(&mut tw, TagType::Anonymous) + } +} - tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; +pub struct ReadDriver<'a, 'r, 'p> { + exchange: &'r mut Exchange<'a>, + tx: &'r mut Packet<'p>, + rx: &'r mut Packet<'p>, + completed: bool, +} - Ok(true) +impl<'a, 'r, 'p> ReadDriver<'a, 'r, 'p> { + fn new(exchange: &'r mut Exchange<'a>, tx: &'r mut Packet<'p>, rx: &'r mut Packet<'p>) -> Self { + Self { + exchange, + tx, + rx, + completed: false, + } } - pub fn complete( - mut self, - tx: &mut Packet, - transaction: &mut Transaction, - resume_path: Option, - ) -> Result { - let mut tw = ReadReq::restore_long_read_space(tx)?; + fn start(&mut self, req: &ReadReq) -> Result<(), Error> { + req.tx_start(self.tx)?; - tw.end_container()?; + Ok(()) + } - let continue_interaction = if let Some(resume_path) = resume_path { - tw.bool( - TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), - true, - )?; + pub fn accessor(&self) -> Result, Error> { + self.exchange.accessor() + } - self.resume_path = resume_path; - transaction.exch_mut().set_suspended_read_req(self); - true + pub fn writer(&mut self) -> Result, Error> { + if self.completed { + Err(ErrorCode::Invalid.into()) // TODO } else { - false - }; + Ok(TLVWriter::new(self.tx.get_writebuf()?)) + } + } - tw.bool( - TagType::Context(msg::ReportDataTag::SupressResponse as u8), - !continue_interaction, - )?; + pub async fn send_chunk(&mut self, req: &ReadReq<'_>) -> Result { + req.tx_finish_chunk(self.tx)?; - tw.end_container()?; + if exchange_confirm(self.exchange, self.tx, self.rx).await? != IMStatusCode::Success { + self.completed = true; + Ok(false) + } else { + req.tx_start(self.tx)?; - if !continue_interaction { - transaction.complete(); + Ok(true) } + } + + pub async fn complete(&mut self, req: &ReadReq<'_>) -> Result<(), Error> { + req.tx_finish(self.tx)?; - Ok(true) + self.exchange.send_complete(self.tx).await } } -#[derive(Debug)] -pub struct ResumeSubscribeReq { - pub subscription_id: u32, - pub paths: heapless::Vec, - pub filters: heapless::Vec, - pub fabric_filtered: bool, - pub resume_path: Option, - pub keep_subs: bool, - pub min_int_floor: u16, - pub max_int_ceil: u16, +pub struct WriteDriver<'a, 'r, 'p> { + exchange: &'r mut Exchange<'a>, + tx: &'r mut Packet<'p>, + epoch: Epoch, + timeout: Option, } -impl ResumeSubscribeReq { - fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); - - if self.resume_path.is_some() { - tx.set_proto_opcode(OpCode::ReportData as u8); - - let mut tw = ReadReq::reserve_long_read_space(tx)?; - - tw.start_struct(TagType::Anonymous)?; - - tw.u32( - TagType::Context(msg::ReportDataTag::SubscriptionId as u8), - self.subscription_id, - )?; - - tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; +impl<'a, 'r, 'p> WriteDriver<'a, 'r, 'p> { + fn new( + exchange: &'r mut Exchange<'a>, + epoch: Epoch, + timeout: Option, + tx: &'r mut Packet<'p>, + ) -> Self { + Self { + exchange, + tx, + epoch, + timeout, + } + } + async fn start(&mut self, req: &WriteReq<'_>) -> Result { + if req.tx_start(self.tx, self.epoch, self.timeout)?.is_some() { Ok(true) } else { - tx.set_proto_opcode(OpCode::SubscribeResponse as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - let resp = SubscribeResp::new(self.subscription_id, 40); - resp.to_tlv(&mut tw, TagType::Anonymous)?; + self.exchange.send_complete(self.tx).await?; Ok(false) } } - pub fn complete( - mut self, - tx: &mut Packet, - transaction: &mut Transaction, - resume_path: Option, - ) -> Result { - if self.resume_path.is_none() { - // Should not get here as initiate() should've sent the subscribe response already - panic!("Subscription was already processed"); + pub fn accessor(&self) -> Result, Error> { + self.exchange.accessor() + } + + pub fn writer(&mut self) -> Result, Error> { + Ok(TLVWriter::new(self.tx.get_writebuf()?)) + } + + pub async fn complete(&mut self, req: &WriteReq<'_>) -> Result<(), Error> { + if !req.supress_response.unwrap_or_default() { + req.tx_finish(self.tx)?; + self.exchange.send_complete(self.tx).await?; } - // Completing a ReportData message + Ok(()) + } +} + +pub struct InvokeDriver<'a, 'r, 'p> { + exchange: &'r mut Exchange<'a>, + tx: &'r mut Packet<'p>, + epoch: Epoch, + timeout: Option, +} - let mut tw = ReadReq::restore_long_read_space(tx)?; +impl<'a, 'r, 'p> InvokeDriver<'a, 'r, 'p> { + fn new( + exchange: &'r mut Exchange<'a>, + epoch: Epoch, + timeout: Option, + tx: &'r mut Packet<'p>, + ) -> Self { + Self { + exchange, + tx, + epoch, + timeout, + } + } - tw.end_container()?; + async fn start(&mut self, req: &InvReq<'_>) -> Result { + if req.tx_start(self.tx, self.epoch, self.timeout)?.is_some() { + Ok(true) + } else { + self.exchange.send_complete(self.tx).await?; - if resume_path.is_some() { - tw.bool( - TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), - true, - )?; + Ok(false) } + } - tw.bool( - TagType::Context(msg::ReportDataTag::SupressResponse as u8), - false, - )?; + pub fn accessor(&self) -> Result, Error> { + self.exchange.accessor() + } - tw.end_container()?; + pub fn writer(&mut self) -> Result, Error> { + Ok(TLVWriter::new(self.tx.get_writebuf()?)) + } - self.resume_path = resume_path; - transaction.exch_mut().set_suspended_subscribe_req(self); + pub fn writer_exchange(&mut self) -> Result<(TLVWriter<'_, 'p>, &Exchange<'a>), Error> { + Ok((TLVWriter::new(self.tx.get_writebuf()?), (self.exchange))) + } + + pub async fn complete(&mut self, req: &InvReq<'_>) -> Result<(), Error> { + if !req.suppress_response.unwrap_or_default() { + req.tx_finish(self.tx)?; + self.exchange.send_complete(self.tx).await?; + } - Ok(true) + Ok(()) } } -pub trait InteractionHandler { - fn handle(&mut self, ctx: &mut ProtoCtx) -> Result; +pub struct SubscribeDriver<'a, 'r, 'p> { + exchange: &'r mut Exchange<'a>, + tx: &'r mut Packet<'p>, + rx: &'r mut Packet<'p>, + subscription_id: u32, + completed: bool, } -impl InteractionHandler for &mut T -where - T: InteractionHandler, -{ - fn handle(&mut self, ctx: &mut ProtoCtx) -> Result { - (**self).handle(ctx) +impl<'a, 'r, 'p> SubscribeDriver<'a, 'r, 'p> { + fn new( + exchange: &'r mut Exchange<'a>, + subscription_id: u32, + tx: &'r mut Packet<'p>, + rx: &'r mut Packet<'p>, + ) -> Self { + Self { + exchange, + tx, + rx, + subscription_id, + completed: false, + } } -} -pub struct InteractionModel(pub T); + fn start(&mut self, req: &SubscribeReq) -> Result<(), Error> { + req.tx_start(self.tx, self.subscription_id)?; -impl InteractionModel -where - T: DataHandler, -{ - pub fn handle(&mut self, ctx: &mut ProtoCtx) -> Result { - let mut transaction = Transaction::new(&mut ctx.exch_ctx); + Ok(()) + } - let reply = - if let Some(interaction) = Interaction::initiate(ctx.rx, ctx.tx, &mut transaction)? { - self.0.handle(interaction, ctx.tx, &mut transaction)? - } else { - true - }; + pub fn accessor(&self) -> Result, Error> { + self.exchange.accessor() + } - if transaction.is_complete() { - transaction.exch_mut().close(); + pub fn writer(&mut self) -> Result, Error> { + if self.completed { + Err(ErrorCode::Invalid.into()) // TODO + } else { + Ok(TLVWriter::new(self.tx.get_writebuf()?)) } + } + + pub async fn send_chunk(&mut self, req: &SubscribeReq<'_>) -> Result { + req.tx_finish_chunk(self.tx, true)?; - Ok(reply) + if exchange_confirm(self.exchange, self.tx, self.rx).await? != IMStatusCode::Success { + self.completed = true; + Ok(false) + } else { + req.tx_start(self.tx, self.subscription_id)?; + + Ok(true) + } } -} -#[cfg(feature = "nightly")] -impl InteractionModel -where - T: crate::data_model::core::asynch::AsyncDataHandler, -{ - pub async fn handle_async<'a>(&mut self, ctx: &mut ProtoCtx<'_, '_>) -> Result { - let mut transaction = Transaction::new(&mut ctx.exch_ctx); - - let reply = - if let Some(interaction) = Interaction::initiate(ctx.rx, ctx.tx, &mut transaction)? { - self.0.handle(interaction, ctx.tx, &mut transaction).await? - } else { - true - }; + pub async fn complete(&mut self, req: &SubscribeReq<'_>) -> Result<(), Error> { + if !self.completed { + req.tx_finish_chunk(self.tx, false)?; - if transaction.is_complete() { - transaction.exch_mut().close(); + if exchange_confirm(self.exchange, self.tx, self.rx).await? != IMStatusCode::Success { + self.completed = true; + } else { + req.tx_process_final(self.tx, self.subscription_id)?; + self.exchange.send_complete(self.tx).await?; + } } - Ok(reply) + Ok(()) } } -impl InteractionHandler for InteractionModel -where - T: DataHandler, -{ - fn handle(&mut self, ctx: &mut ProtoCtx) -> Result { - InteractionModel::handle(self, ctx) - } +pub enum Interaction<'a, 'r, 'p> { + Read { + req: ReadReq<'r>, + driver: ReadDriver<'a, 'r, 'p>, + }, + Write { + req: WriteReq<'r>, + driver: WriteDriver<'a, 'r, 'p>, + }, + Invoke { + req: InvReq<'r>, + driver: InvokeDriver<'a, 'r, 'p>, + }, + Subscribe { + req: SubscribeReq<'r>, + driver: SubscribeDriver<'a, 'r, 'p>, + }, } -#[cfg(feature = "nightly")] -pub mod asynch { - use crate::{ - data_model::core::asynch::AsyncDataHandler, error::Error, transport::proto_ctx::ProtoCtx, - }; +impl<'a, 'r, 'p> Interaction<'a, 'r, 'p> { + pub async fn timeout( + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + ) -> Result, Error> { + let epoch = exchange.transport().matter().epoch; - use super::InteractionModel; + let mut opcode: OpCode = rx.get_proto_opcode()?; - pub trait AsyncInteractionHandler { - async fn handle(&mut self, ctx: &mut ProtoCtx<'_, '_>) -> Result; - } + let mut timeout = None; - impl AsyncInteractionHandler for &mut T - where - T: AsyncInteractionHandler, - { - async fn handle(&mut self, ctx: &mut ProtoCtx<'_, '_>) -> Result { - (**self).handle(ctx).await + while opcode == OpCode::TimedRequest { + let rx_data = rx.as_slice(); + let req = TimedReq::from_tlv(&get_root_node_struct(rx_data)?)?; + + timeout = Some(req.tx_process(tx, epoch)?); + + exchange.exchange(tx, rx).await?; + + opcode = rx.get_proto_opcode()?; } + + Ok(timeout) } - impl AsyncInteractionHandler for InteractionModel + #[inline(always)] + pub fn new( + exchange: &'r mut Exchange<'a>, + rx: &'r mut Packet<'p>, + tx: &'r mut Packet<'p>, + rx_status: &'r mut Packet<'p>, + subscription_id: S, + timeout: Option, + ) -> Result, Error> where - T: AsyncDataHandler, + S: FnOnce() -> u32, { - async fn handle(&mut self, ctx: &mut ProtoCtx<'_, '_>) -> Result { - InteractionModel::handle_async(self, ctx).await + let epoch = exchange.transport().matter().epoch; + + let opcode = rx.get_proto_opcode()?; + let rx_data = rx.as_slice(); + + match opcode { + OpCode::ReadRequest => { + let req = ReadReq::from_tlv(&get_root_node_struct(rx_data)?)?; + let driver = ReadDriver::new(exchange, tx, rx_status); + + Ok(Self::Read { req, driver }) + } + OpCode::WriteRequest => { + let req = WriteReq::from_tlv(&get_root_node_struct(rx_data)?)?; + let driver = WriteDriver::new(exchange, epoch, timeout, tx); + + Ok(Self::Write { req, driver }) + } + OpCode::InvokeRequest => { + let req = InvReq::from_tlv(&get_root_node_struct(rx_data)?)?; + let driver = InvokeDriver::new(exchange, epoch, timeout, tx); + + Ok(Self::Invoke { req, driver }) + } + OpCode::SubscribeRequest => { + let req = SubscribeReq::from_tlv(&get_root_node_struct(rx_data)?)?; + let driver = SubscribeDriver::new(exchange, subscription_id(), tx, rx_status); + + Ok(Self::Subscribe { req, driver }) + } + _ => { + error!("Opcode not handled: {:?}", opcode); + Err(ErrorCode::InvalidOpcode.into()) + } } } + + pub async fn start(&mut self) -> Result { + let started = match self { + Self::Read { req, driver } => { + driver.start(req)?; + true + } + Self::Write { req, driver } => driver.start(req).await?, + Self::Invoke { req, driver } => driver.start(req).await?, + Self::Subscribe { req, driver } => { + driver.start(req)?; + true + } + }; + + Ok(started) + } + + fn status_response(tx: &mut Packet, status: IMStatusCode) -> Result<(), Error> { + tx.reset(); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); + tx.set_proto_opcode(OpCode::StatusResponse as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + let status = StatusResp { status }; + status.to_tlv(&mut tw, TagType::Anonymous) + } +} + +async fn exchange_confirm( + exchange: &mut Exchange<'_>, + tx: &mut Packet<'_>, + rx: &mut Packet<'_>, +) -> Result { + exchange.exchange(tx, rx).await?; + + let opcode: OpCode = rx.get_proto_opcode()?; + + if opcode == OpCode::StatusResponse { + let resp = StatusResp::from_tlv(&get_root_node_struct(rx.as_slice())?)?; + Ok(resp.status) + } else { + Interaction::status_response(tx, IMStatusCode::Busy)?; // TODO + + exchange.send_complete(tx).await?; + + Err(ErrorCode::Invalid.into()) // TODO + } +} + +fn has_timed_out(epoch: Epoch, timeout: Option) -> bool { + timeout.map(|timeout| epoch() > timeout).unwrap_or(false) } diff --git a/matter/src/lib.rs b/matter/src/lib.rs index 1d7e5d4a..b80a62c4 100644 --- a/matter/src/lib.rs +++ b/matter/src/lib.rs @@ -69,6 +69,7 @@ //! Start off exploring by going to the [Matter] object. #![cfg_attr(not(feature = "std"), no_std)] #![cfg_attr(feature = "nightly", feature(async_fn_in_trait))] +#![cfg_attr(feature = "nightly", feature(impl_trait_projections))] #![cfg_attr(feature = "nightly", allow(incomplete_features))] pub mod acl; @@ -90,3 +91,22 @@ pub mod transport; pub mod utils; pub use crate::core::*; + +#[cfg(feature = "alloc")] +extern crate alloc; + +#[cfg(feature = "alloc")] +#[macro_export] +macro_rules! alloc { + ($val:expr) => { + alloc::boxed::Box::new($val) + }; +} + +#[cfg(not(feature = "alloc"))] +#[macro_export] +macro_rules! alloc { + ($val:expr) => { + $val + }; +} diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index 63d5e56e..28f4508e 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -20,30 +20,25 @@ use core::cell::RefCell; use log::{error, trace}; use crate::{ + alloc, cert::Cert, crypto::{self, KeyPair, Sha256}, error::{Error, ErrorCode}, fabric::{Fabric, FabricMgr}, - secure_channel::common::SCStatusCodes, - secure_channel::common::{self, OpCode}, + secure_channel::common::{self, OpCode, PROTO_ID_SECURE_CHANNEL}, + secure_channel::common::{complete_with_status, SCStatusCodes}, tlv::{get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType}, transport::{ + exchange::Exchange, network::Address, - proto_ctx::ProtoCtx, + packet::Packet, session::{CaseDetails, CloneData, NocCatIds, SessionMode}, }, utils::{rand::Rand, writebuf::WriteBuf}, }; -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -enum State { - Sigma1Rx, - Sigma3Rx, -} - #[derive(Debug, Clone)] -pub struct CaseSession { - state: State, +struct CaseSession { peer_sessid: u16, local_sessid: u16, tt_hash: Sha256, @@ -54,11 +49,11 @@ pub struct CaseSession { } impl CaseSession { - pub fn new(peer_sessid: u16, local_sessid: u16) -> Result { + #[inline(always)] + pub fn new() -> Result { Ok(Self { - state: State::Sigma1Rx, - peer_sessid, - local_sessid, + peer_sessid: 0, + local_sessid: 0, tt_hash: Sha256::new()?, shared_secret: [0; crypto::ECDH_SHARED_SECRET_LEN_BYTES], our_pub_key: [0; crypto::EC_POINT_LEN_BYTES], @@ -79,39 +74,50 @@ impl<'a> Case<'a> { Self { fabric_mgr, rand } } - pub fn casesigma3_handler( + pub async fn handle( &mut self, - ctx: &mut ProtoCtx, - ) -> Result<(bool, Option), Error> { - let mut case_session = ctx - .exch_ctx - .exch - .take_case_session() - .ok_or(ErrorCode::InvalidState)?; - if case_session.state != State::Sigma1Rx { - Err(ErrorCode::Invalid)?; - } - case_session.state = State::Sigma3Rx; + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + ) -> Result<(), Error> { + let mut session = alloc!(CaseSession::new()?); + + self.handle_casesigma1(exchange, rx, tx, &mut session) + .await?; + self.handle_casesigma3(exchange, rx, tx, &mut session).await + } + + #[allow(clippy::await_holding_refcell_ref)] + async fn handle_casesigma3( + &mut self, + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + case_session: &mut CaseSession, + ) -> Result<(), Error> { + rx.check_proto_opcode(OpCode::CASESigma3 as _)?; let fabric_mgr = self.fabric_mgr.borrow(); let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; if fabric.is_none() { - common::create_sc_status_report( - ctx.tx, + drop(fabric_mgr); + complete_with_status( + exchange, + tx, common::SCStatusCodes::NoSharedTrustRoots, None, - )?; - ctx.exch_ctx.exch.close(); - return Ok((true, None)); + ) + .await?; + return Ok(()); } // Safe to unwrap here let fabric = fabric.unwrap(); - let root = get_root_node_struct(ctx.rx.as_slice())?; + let root = get_root_node_struct(rx.as_slice())?; let encrypted = root.find_tag(1)?.slice()?; - let mut decrypted: [u8; 800] = [0; 800]; + let mut decrypted = alloc!([0; 800]); if encrypted.len() > decrypted.len() { error!("Data too large"); Err(ErrorCode::NoSpace)?; @@ -119,22 +125,29 @@ impl<'a> Case<'a> { let decrypted = &mut decrypted[..encrypted.len()]; decrypted.copy_from_slice(encrypted); - let len = Case::get_sigma3_decryption(fabric.ipk.op_key(), &case_session, decrypted)?; + let len = Case::get_sigma3_decryption(fabric.ipk.op_key(), case_session, decrypted)?; let decrypted = &decrypted[..len]; let root = get_root_node_struct(decrypted)?; let d = Sigma3Decrypt::from_tlv(&root)?; - let initiator_noc = Cert::new(d.initiator_noc.0)?; + let initiator_noc = alloc!(Cert::new(d.initiator_noc.0)?); let mut initiator_icac = None; if let Some(icac) = d.initiator_icac { - initiator_icac = Some(Cert::new(icac.0)?); + initiator_icac = Some(alloc!(Cert::new(icac.0)?)); } - if let Err(e) = Case::validate_certs(fabric, &initiator_noc, &initiator_icac) { + + #[cfg(feature = "alloc")] + let initiator_icac_mut = initiator_icac.as_deref(); + + #[cfg(not(feature = "alloc"))] + let initiator_icac_mut = initiator_icac.as_ref(); + + if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) { error!("Certificate Chain doesn't match: {}", e); - common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?; - ctx.exch_ctx.exch.close(); - return Ok((true, None)); + complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None) + .await?; + return Ok(()); } if Case::validate_sigma3_sign( @@ -142,39 +155,52 @@ impl<'a> Case<'a> { d.initiator_icac.map(|a| a.0), &initiator_noc, d.signature.0, - &case_session, + case_session, ) .is_err() { error!("Sigma3 Signature doesn't match"); - common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?; - ctx.exch_ctx.exch.close(); - return Ok((true, None)); + complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None) + .await?; + return Ok(()); } // Only now do we add this message to the TT Hash let mut peer_catids: NocCatIds = Default::default(); initiator_noc.get_cat_ids(&mut peer_catids); - case_session.tt_hash.update(ctx.rx.as_slice())?; + case_session.tt_hash.update(rx.as_slice())?; let clone_data = Case::get_session_clone_data( fabric.ipk.op_key(), fabric.get_node_id(), initiator_noc.get_node_id()?, - ctx.exch_ctx.sess.get_peer_addr(), - &case_session, + exchange.with_session(|sess| Ok(sess.get_peer_addr()))?, + case_session, &peer_catids, )?; - common::create_sc_status_report(ctx.tx, SCStatusCodes::SessionEstablishmentSuccess, None)?; - ctx.exch_ctx.exch.clear_data(); - ctx.exch_ctx.exch.close(); - Ok((true, Some(clone_data))) + // TODO: Handle NoSpace + exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?; + + complete_with_status( + exchange, + tx, + SCStatusCodes::SessionEstablishmentSuccess, + None, + ) + .await } - pub fn casesigma1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { - ctx.tx.set_proto_opcode(OpCode::CASESigma2 as u8); + #[allow(clippy::await_holding_refcell_ref)] + async fn handle_casesigma1( + &mut self, + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + case_session: &mut CaseSession, + ) -> Result<(), Error> { + rx.check_proto_opcode(OpCode::CASESigma1 as _)?; - let rx_buf = ctx.rx.as_slice(); + let rx_buf = rx.as_slice(); let root = get_root_node_struct(rx_buf)?; let r = Sigma1Req::from_tlv(&root)?; @@ -184,17 +210,20 @@ impl<'a> Case<'a> { .match_dest_id(r.initiator_random.0, r.dest_id.0); if local_fabric_idx.is_err() { error!("Fabric Index mismatch"); - common::create_sc_status_report( - ctx.tx, + complete_with_status( + exchange, + tx, common::SCStatusCodes::NoSharedTrustRoots, None, - )?; - ctx.exch_ctx.exch.close(); - return Ok(true); + ) + .await?; + + return Ok(()); } - let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id(); - let mut case_session = CaseSession::new(r.initiator_sessid, local_sessid)?; + let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?; + case_session.peer_sessid = r.initiator_sessid; + case_session.local_sessid = local_sessid; case_session.tt_hash.update(rx_buf)?; case_session.local_fabric_idx = local_fabric_idx?; if r.peer_pub_key.0.len() != crypto::EC_POINT_LEN_BYTES { @@ -225,52 +254,71 @@ impl<'a> Case<'a> { // Derive the Encrypted Part const MAX_ENCRYPTED_SIZE: usize = 800; - let mut encrypted: [u8; MAX_ENCRYPTED_SIZE] = [0; MAX_ENCRYPTED_SIZE]; + let mut encrypted = alloc!([0; MAX_ENCRYPTED_SIZE]); let encrypted_len = { - let mut signature = [0u8; crypto::EC_SIGNATURE_LEN_BYTES]; + let mut signature = alloc!([0u8; crypto::EC_SIGNATURE_LEN_BYTES]); let fabric_mgr = self.fabric_mgr.borrow(); let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; if fabric.is_none() { - common::create_sc_status_report( - ctx.tx, + drop(fabric_mgr); + complete_with_status( + exchange, + tx, common::SCStatusCodes::NoSharedTrustRoots, None, - )?; - ctx.exch_ctx.exch.close(); - return Ok(true); + ) + .await?; + return Ok(()); } + #[cfg(feature = "alloc")] + let signature_mut = &mut *signature; + + #[cfg(not(feature = "alloc"))] + let signature_mut = &mut signature; + let sign_len = Case::get_sigma2_sign( fabric.unwrap(), &case_session.our_pub_key, &case_session.peer_pub_key, - &mut signature, + signature_mut, )?; let signature = &signature[..sign_len]; + #[cfg(feature = "alloc")] + let encrypted_mut = &mut *encrypted; + + #[cfg(not(feature = "alloc"))] + let encrypted_mut = &mut encrypted; + Case::get_sigma2_encryption( fabric.unwrap(), self.rand, &our_random, - &mut case_session, + case_session, signature, - &mut encrypted, + encrypted_mut, )? }; let encrypted = &encrypted[0..encrypted_len]; // Generate our Response Body - let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?); + tx.reset(); + tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); + tx.set_proto_opcode(OpCode::CASESigma2 as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); tw.start_struct(TagType::Anonymous)?; tw.str8(TagType::Context(1), &our_random)?; tw.u16(TagType::Context(2), local_sessid)?; tw.str8(TagType::Context(3), &case_session.our_pub_key)?; tw.str16(TagType::Context(4), encrypted)?; tw.end_container()?; - case_session.tt_hash.update(ctx.tx.as_mut_slice())?; - ctx.exch_ctx.exch.set_case_session(case_session); - Ok(true) + + case_session.tt_hash.update(tx.as_mut_slice())?; + + exchange.exchange(tx, rx).await } fn get_session_clone_data( @@ -334,7 +382,7 @@ impl<'a> Case<'a> { Ok(()) } - fn validate_certs(fabric: &Fabric, noc: &Cert, icac: &Option) -> Result<(), Error> { + fn validate_certs(fabric: &Fabric, noc: &Cert, icac: Option<&Cert>) -> Result<(), Error> { let mut verifier = noc.verify_chain_start(); if fabric.get_fabric_id() != noc.get_fabric_id()? { diff --git a/matter/src/secure_channel/common.rs b/matter/src/secure_channel/common.rs index 80fb7b51..2f00ed45 100644 --- a/matter/src/secure_channel/common.rs +++ b/matter/src/secure_channel/common.rs @@ -17,7 +17,10 @@ use num_derive::FromPrimitive; -use crate::{error::Error, transport::packet::Packet}; +use crate::{ + error::Error, + transport::{exchange::Exchange, packet::Packet}, +}; use super::status_report::{create_status_report, GeneralCode}; @@ -51,6 +54,17 @@ pub enum SCStatusCodes { SessionNotFound = 5, } +pub async fn complete_with_status( + exchange: &mut Exchange<'_>, + tx: &mut Packet<'_>, + status_code: SCStatusCodes, + proto_data: Option<&[u8]>, +) -> Result<(), Error> { + create_sc_status_report(tx, status_code, proto_data)?; + + exchange.send_complete(tx).await +} + pub fn create_sc_status_report( proto_tx: &mut Packet, status_code: SCStatusCodes, diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index 0ad17ed3..b20ea9a5 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -15,18 +15,19 @@ * limitations under the License. */ -use core::{borrow::Borrow, cell::RefCell}; +use core::borrow::Borrow; +use core::cell::RefCell; + +use log::error; use crate::{ error::*, fabric::FabricMgr, mdns::Mdns, - secure_channel::common::*, - tlv, - transport::{proto_ctx::ProtoCtx, session::CloneData}, + secure_channel::{common::*, pake::Pake}, + transport::{exchange::Exchange, packet::Packet}, utils::{epoch::Epoch, rand::Rand}, }; -use log::{error, info}; use super::{case::Case, pake::PaseMgr}; @@ -34,9 +35,10 @@ use super::{case::Case, pake::PaseMgr}; */ pub struct SecureChannel<'a> { - case: Case<'a>, pase: &'a RefCell, + fabric: &'a RefCell, mdns: &'a dyn Mdns, + rand: Rand, } impl<'a> SecureChannel<'a> { @@ -66,45 +68,34 @@ impl<'a> SecureChannel<'a> { rand: Rand, ) -> Self { Self { - case: Case::new(fabric, rand), + fabric, pase, mdns, + rand, } } - pub fn handle(&mut self, ctx: &mut ProtoCtx) -> Result<(bool, Option), Error> { - let proto_opcode: OpCode = ctx.rx.get_proto_opcode()?; - - ctx.tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); - info!("Received Opcode: {:?}", proto_opcode); - info!("Received Data:"); - tlv::print_tlv_list(ctx.rx.as_slice()); - let (reply, clone_data) = match proto_opcode { - OpCode::MRPStandAloneAck => Ok((false, None)), - OpCode::PBKDFParamRequest => self - .pase - .borrow_mut() - .pbkdfparamreq_handler(ctx) - .map(|reply| (reply, None)), - OpCode::PASEPake1 => self - .pase - .borrow_mut() - .pasepake1_handler(ctx) - .map(|reply| (reply, None)), - OpCode::PASEPake3 => self.pase.borrow_mut().pasepake3_handler(ctx, self.mdns), - OpCode::CASESigma1 => self.case.casesigma1_handler(ctx).map(|reply| (reply, None)), - OpCode::CASESigma3 => self.case.casesigma3_handler(ctx), - _ => { - error!("OpCode Not Handled: {:?}", proto_opcode); + pub async fn handle( + &self, + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + ) -> Result<(), Error> { + match rx.get_proto_opcode()? { + OpCode::PBKDFParamRequest => { + Pake::new(self.pase) + .handle(exchange, rx, tx, self.mdns) + .await + } + OpCode::CASESigma1 => { + Case::new(self.fabric, self.rand) + .handle(exchange, rx, tx) + .await + } + proto_opcode => { + error!("OpCode not handled: {:?}", proto_opcode); Err(ErrorCode::InvalidOpcode.into()) } - }?; - - if reply { - info!("Sending response"); - tlv::print_tlv_list(ctx.tx.as_mut_slice()); } - - Ok((reply, clone_data)) } } diff --git a/matter/src/secure_channel/pake.rs b/matter/src/secure_channel/pake.rs index 79f7d2cd..ea2b98c0 100644 --- a/matter/src/secure_channel/pake.rs +++ b/matter/src/secure_channel/pake.rs @@ -15,36 +15,35 @@ * limitations under the License. */ -use core::{fmt::Write, time::Duration}; +use core::{cell::RefCell, fmt::Write, time::Duration}; use super::{ - common::{create_sc_status_report, SCStatusCodes}, + common::{SCStatusCodes, PROTO_ID_SECURE_CHANNEL}, spake2p::{Spake2P, VerifierData}, }; use crate::{ - crypto, + alloc, crypto, error::{Error, ErrorCode}, mdns::{Mdns, ServiceMode}, - secure_channel::common::OpCode, + secure_channel::common::{complete_with_status, OpCode}, tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV}, transport::{ - exchange::ExchangeCtx, - network::Address, - proto_ctx::ProtoCtx, + exchange::{Exchange, ExchangeId}, + packet::Packet, session::{CloneData, SessionMode}, }, utils::{epoch::Epoch, rand::Rand}, }; use log::{error, info}; -#[allow(clippy::large_enum_variant)] -enum PaseMgrState { - Enabled(Pake, heapless::String<16>), - Disabled, +struct PaseSession { + mdns_service_name: heapless::String<16>, + verifier: VerifierData, } pub struct PaseMgr { - state: PaseMgrState, + session: Option, + timeout: Option, epoch: Epoch, rand: Rand, } @@ -53,14 +52,15 @@ impl PaseMgr { #[inline(always)] pub fn new(epoch: Epoch, rand: Rand) -> Self { Self { - state: PaseMgrState::Disabled, + session: None, + timeout: None, epoch, rand, } } pub fn is_pase_session_enabled(&self) -> bool { - matches!(&self.state, PaseMgrState::Enabled(_, _)) + self.session.is_some() } pub fn enable_pase_session( @@ -80,62 +80,24 @@ impl PaseMgr { &mdns_service_name, ServiceMode::Commissionable(discriminator), )?; - self.state = PaseMgrState::Enabled( - Pake::new(verifier, self.epoch, self.rand), + + self.session = Some(PaseSession { mdns_service_name, - ); + verifier, + }); Ok(()) } pub fn disable_pase_session(&mut self, mdns: &dyn Mdns) -> Result<(), Error> { - if let PaseMgrState::Enabled(_, mdns_service_name) = &self.state { - mdns.remove(mdns_service_name)?; + if let Some(session) = self.session.as_ref() { + mdns.remove(&session.mdns_service_name)?; } - self.state = PaseMgrState::Disabled; + self.session = None; Ok(()) } - - /// If the PASE Session is enabled, execute the closure, - /// if not enabled, generate SC Status Report - fn if_enabled(&mut self, ctx: &mut ProtoCtx, f: F) -> Result, Error> - where - F: FnOnce(&mut Pake, &mut ProtoCtx) -> Result, - { - if let PaseMgrState::Enabled(pake, _) = &mut self.state { - let data = f(pake, ctx)?; - - Ok(Some(data)) - } else { - error!("PASE Not enabled"); - create_sc_status_report(ctx.tx, SCStatusCodes::InvalidParameter, None)?; - Ok(None) - } - } - - pub fn pbkdfparamreq_handler(&mut self, ctx: &mut ProtoCtx) -> Result { - ctx.tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8); - self.if_enabled(ctx, |pake, ctx| pake.handle_pbkdfparamrequest(ctx))?; - Ok(true) - } - - pub fn pasepake1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { - ctx.tx.set_proto_opcode(OpCode::PASEPake2 as u8); - self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake1(ctx))?; - Ok(true) - } - - pub fn pasepake3_handler( - &mut self, - ctx: &mut ProtoCtx, - mdns: &dyn Mdns, - ) -> Result<(bool, Option), Error> { - let clone_data = self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?; - self.disable_pase_session(mdns)?; - Ok((true, clone_data.flatten())) - } } // This file basically deals with the handlers for the PASE secure channel protocol @@ -147,96 +109,65 @@ const PASE_DISCARD_TIMEOUT_SECS: Duration = Duration::from_secs(60); const SPAKE2_SESSION_KEYS_INFO: [u8; 11] = *b"SessionKeys"; -struct SessionData { +struct Timeout { start_time: Duration, - exch_id: u16, - peer_addr: Address, - spake2p: Spake2P, -} - -impl SessionData { - fn is_sess_expired(&self, epoch: Epoch) -> Result { - Ok(epoch() - self.start_time > PASE_DISCARD_TIMEOUT_SECS) - } + exch_id: ExchangeId, } -#[allow(clippy::large_enum_variant)] -enum PakeState { - Idle, - InProgress(SessionData), -} - -impl PakeState { - const fn new() -> Self { - Self::Idle - } - - fn take(&mut self) -> Result { - let new = core::mem::replace(self, PakeState::Idle); - if let PakeState::InProgress(s) = new { - Ok(s) - } else { - Err(ErrorCode::InvalidSignature.into()) - } - } - - fn is_idle(&self) -> bool { - core::mem::discriminant(self) == core::mem::discriminant(&PakeState::Idle) - } - - fn take_sess_data(&mut self, exch_ctx: &ExchangeCtx) -> Result { - let sd = self.take()?; - if sd.exch_id != exch_ctx.exch.get_id() || sd.peer_addr != exch_ctx.sess.get_peer_addr() { - Err(ErrorCode::InvalidState.into()) - } else { - Ok(sd) - } - } - - fn make_in_progress(&mut self, epoch: Epoch, spake2p: Spake2P, exch_ctx: &ExchangeCtx) { - *self = PakeState::InProgress(SessionData { +impl Timeout { + fn new(exchange: &Exchange, epoch: Epoch) -> Self { + Self { start_time: epoch(), - spake2p, - exch_id: exch_ctx.exch.get_id(), - peer_addr: exch_ctx.sess.get_peer_addr(), - }); + exch_id: exchange.id().clone(), + } } - fn set_sess_data(&mut self, sd: SessionData) { - *self = PakeState::InProgress(sd); + fn is_sess_expired(&self, epoch: Epoch) -> bool { + epoch() - self.start_time > PASE_DISCARD_TIMEOUT_SECS } } -impl Default for PakeState { - fn default() -> Self { - Self::new() - } +pub struct Pake<'a> { + pase: &'a RefCell, } -struct Pake { - verifier: VerifierData, - state: PakeState, - epoch: Epoch, - rand: Rand, -} - -impl Pake { - pub fn new(verifier: VerifierData, epoch: Epoch, rand: Rand) -> Self { +impl<'a> Pake<'a> { + pub const fn new(pase: &'a RefCell) -> Self { // TODO: Can any PBKDF2 calculation be pre-computed here - Self { - verifier, - state: PakeState::new(), - epoch, - rand, - } + Self { pase } + } + + pub async fn handle( + &mut self, + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + mdns: &dyn Mdns, + ) -> Result<(), Error> { + let mut spake2p = alloc!(Spake2P::new()); + + self.handle_pbkdfparamrequest(exchange, rx, tx, &mut spake2p) + .await?; + self.handle_pasepake1(exchange, rx, tx, &mut spake2p) + .await?; + self.handle_pasepake3(exchange, rx, tx, mdns, &mut spake2p) + .await } #[allow(non_snake_case)] - pub fn handle_pasepake3(&mut self, ctx: &mut ProtoCtx) -> Result, Error> { - let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; + async fn handle_pasepake3( + &mut self, + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + mdns: &dyn Mdns, + spake2p: &mut Spake2P, + ) -> Result<(), Error> { + rx.check_proto_opcode(OpCode::PASEPake3 as _)?; + self.update_timeout(exchange, tx, true).await?; - let cA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?; - let (status_code, ke) = sd.spake2p.handle_cA(cA); + let cA = extract_pasepake_1_or_3_params(rx.as_slice())?; + let (status_code, ke) = spake2p.handle_cA(cA); let clone_data = if status_code == SCStatusCodes::SessionEstablishmentSuccess { // Get the keys @@ -246,7 +177,7 @@ impl Pake { .map_err(|_x| ErrorCode::NoSpace)?; // Create a session - let data = sd.spake2p.get_app_data(); + let data = spake2p.get_app_data(); let peer_sessid: u16 = (data & 0xffff) as u16; let local_sessid: u16 = ((data >> 16) & 0xffff) as u16; let mut clone_data = CloneData::new( @@ -254,7 +185,7 @@ impl Pake { 0, peer_sessid, local_sessid, - ctx.exch_ctx.sess.get_peer_addr(), + exchange.with_session(|sess| Ok(sess.get_peer_addr()))?, SessionMode::Pase, ); clone_data.dec_key.copy_from_slice(&session_keys[0..16]); @@ -269,48 +200,70 @@ impl Pake { None }; - create_sc_status_report(ctx.tx, status_code, None)?; - ctx.exch_ctx.exch.close(); - Ok(clone_data) + if let Some(clone_data) = clone_data { + // TODO: Handle NoSpace + exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?; + + self.pase.borrow_mut().disable_pase_session(mdns)?; + } + + complete_with_status(exchange, tx, status_code, None).await?; + + Ok(()) } #[allow(non_snake_case)] - pub fn handle_pasepake1(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { - let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; + #[allow(clippy::await_holding_refcell_ref)] + async fn handle_pasepake1( + &mut self, + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + spake2p: &mut Spake2P, + ) -> Result<(), Error> { + rx.check_proto_opcode(OpCode::PASEPake1 as _)?; + self.update_timeout(exchange, tx, false).await?; + + let pase = self.pase.borrow(); + let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?; - let pA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?; + let pA = extract_pasepake_1_or_3_params(rx.as_slice())?; let mut pB: [u8; 65] = [0; 65]; let mut cB: [u8; 32] = [0; 32]; - sd.spake2p.start_verifier(&self.verifier)?; - sd.spake2p.handle_pA(pA, &mut pB, &mut cB, self.rand)?; + spake2p.start_verifier(&session.verifier)?; + spake2p.handle_pA(pA, &mut pB, &mut cB, pase.rand)?; - let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?); + // Generate response + tx.reset(); + tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); + tx.set_proto_opcode(OpCode::PASEPake2 as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); let resp = Pake1Resp { pb: OctetStr(&pB), cb: OctetStr(&cB), }; resp.to_tlv(&mut tw, TagType::Anonymous)?; - self.state.set_sess_data(sd); - - Ok(()) + drop(pase); + exchange.exchange(tx, rx).await } - pub fn handle_pbkdfparamrequest(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { - if !self.state.is_idle() { - let sd = self.state.take()?; - if sd.is_sess_expired(self.epoch)? { - info!("Previous session expired, clearing it"); - self.state = PakeState::Idle; - } else { - info!("Previous session in-progress, denying new request"); - // little-endian timeout (here we've hardcoded 500ms) - create_sc_status_report(ctx.tx, SCStatusCodes::Busy, Some(&[0xf4, 0x01]))?; - return Ok(()); - } - } + #[allow(clippy::await_holding_refcell_ref)] + async fn handle_pbkdfparamrequest( + &mut self, + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + spake2p: &mut Spake2P, + ) -> Result<(), Error> { + rx.check_proto_opcode(OpCode::PBKDFParamRequest as _)?; + self.update_timeout(exchange, tx, true).await?; - let root = tlv::get_root_node(ctx.rx.as_slice())?; + let pase = self.pase.borrow(); + let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?; + + let root = tlv::get_root_node(rx.as_slice())?; let a = PBKDFParamReq::from_tlv(&root)?; if a.passcode_id != 0 { error!("Can't yet handle passcode_id != 0"); @@ -318,15 +271,18 @@ impl Pake { } let mut our_random: [u8; 32] = [0; 32]; - (self.rand)(&mut our_random); + (self.pase.borrow().rand)(&mut our_random); - let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id(); + let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?; let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32; - let mut spake2p = Spake2P::new(); spake2p.set_app_data(spake2p_data); // Generate response - let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?); + tx.reset(); + tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); + tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); let mut resp = PBKDFParamResp { init_random: a.initiator_random, our_random: OctetStr(&our_random), @@ -335,18 +291,76 @@ impl Pake { }; if !a.has_params { let params_resp = PBKDFParamRespParams { - count: self.verifier.count, - salt: OctetStr(&self.verifier.salt), + count: session.verifier.count, + salt: OctetStr(&session.verifier.salt), }; resp.params = Some(params_resp); } resp.to_tlv(&mut tw, TagType::Anonymous)?; - spake2p.set_context(ctx.rx.as_slice(), ctx.tx.as_mut_slice())?; - self.state - .make_in_progress(self.epoch, spake2p, &ctx.exch_ctx); + spake2p.set_context(rx.as_slice(), tx.as_mut_slice())?; - Ok(()) + drop(pase); + + exchange.exchange(tx, rx).await + } + + #[allow(clippy::await_holding_refcell_ref)] + async fn update_timeout( + &mut self, + exchange: &mut Exchange<'_>, + tx: &mut Packet<'_>, + new: bool, + ) -> Result<(), Error> { + self.check_session(exchange, tx).await?; + + let mut pase = self.pase.borrow_mut(); + + if pase + .timeout + .as_ref() + .map(|sd| sd.is_sess_expired(pase.epoch)) + .unwrap_or(false) + { + pase.timeout = None; + } + + let status = if let Some(sd) = pase.timeout.as_mut() { + if &sd.exch_id != exchange.id() { + info!("Other PAKE session in progress"); + Some(SCStatusCodes::Busy) + } else { + None + } + } else if new { + None + } else { + error!("PAKE session not found or expired"); + Some(SCStatusCodes::SessionNotFound) + }; + + if let Some(status) = status { + drop(pase); + + complete_with_status(exchange, tx, status, None).await + } else { + pase.timeout = Some(Timeout::new(exchange, pase.epoch)); + + Ok(()) + } + } + + async fn check_session( + &mut self, + exchange: &mut Exchange<'_>, + tx: &mut Packet<'_>, + ) -> Result<(), Error> { + if self.pase.borrow().session.is_none() { + error!("PASE not enabled"); + complete_with_status(exchange, tx, SCStatusCodes::InvalidParameter, None).await + } else { + Ok(()) + } } } diff --git a/matter/src/transport/core.rs b/matter/src/transport/core.rs index 1b169eec..2a54b4a3 100644 --- a/matter/src/transport/core.rs +++ b/matter/src/transport/core.rs @@ -15,237 +15,413 @@ * limitations under the License. */ -use log::info; - -use crate::{error::*, CommissioningData, Matter}; +use core::{borrow::Borrow, cell::RefCell}; -use crate::secure_channel::common::PROTO_ID_SECURE_CHANNEL; -use crate::secure_channel::core::SecureChannel; -use crate::transport::mrp::ReliableMessage; -use crate::transport::{exchange, network::Address, packet::Packet}; - -use super::proto_ctx::ProtoCtx; -use super::session::CloneData; +use crate::{error::ErrorCode, secure_channel::common::OpCode, Matter}; +use embassy_futures::select::select; +use embassy_time::{Duration, Timer}; +use log::info; -enum RecvState { - New, - OpenExchange, - AddSession(CloneData), - EvictSession, - EvictSession2(CloneData), - Ack, +use crate::{ + error::Error, secure_channel::common::PROTO_ID_SECURE_CHANNEL, transport::packet::Packet, +}; + +use super::{ + exchange::{ + Exchange, ExchangeCtr, ExchangeCtx, ExchangeId, ExchangeState, Notification, Role, + MAX_EXCHANGES, + }, + mrp::ReliableMessage, + session::SessionMgr, +}; + +#[derive(Debug)] +enum OpCodeDescriptor { + SecureChannel(OpCode), + InteractionModel(crate::interaction_model::core::OpCode), + Unknown(u8), } -pub enum RecvAction<'r, 'p> { - Send(Address, &'r [u8]), - Interact(ProtoCtx<'r, 'p>), +impl From for OpCodeDescriptor { + fn from(value: u8) -> Self { + if let Some(opcode) = num::FromPrimitive::from_u8(value) { + Self::SecureChannel(opcode) + } else if let Some(opcode) = num::FromPrimitive::from_u8(value) { + Self::InteractionModel(opcode) + } else { + Self::Unknown(value) + } + } } -pub struct RecvCompletion<'r, 'a> { - transport: &'r mut Transport<'a>, - rx: Packet<'r>, - tx: Packet<'r>, - state: RecvState, +pub struct Transport<'a> { + matter: &'a Matter<'a>, + pub(crate) exchanges: RefCell>, + pub(crate) send_notification: Notification, + pub(crate) persist_notification: Notification, + pub session_mgr: RefCell, } -impl<'r, 'a> RecvCompletion<'r, 'a> { - pub fn next_action(&mut self) -> Result>, Error> { - loop { - // Polonius will remove the need for unsafe one day - let this = unsafe { (self as *mut RecvCompletion).as_mut().unwrap() }; +impl<'a> Transport<'a> { + #[inline(always)] + pub fn new(matter: &'a Matter<'a>) -> Self { + let epoch = matter.epoch; + let rand = matter.rand; - if let Some(action) = this.maybe_next_action()? { - return Ok(action); - } + Self { + matter, + exchanges: RefCell::new(heapless::Vec::new()), + send_notification: Notification::new(), + persist_notification: Notification::new(), + session_mgr: RefCell::new(SessionMgr::new(epoch, rand)), } } - fn maybe_next_action(&mut self) -> Result>>, Error> { - self.transport.exch_mgr.purge(); - self.tx.reset(); + pub fn matter(&self) -> &'a Matter<'a> { + self.matter + } - let (state, next) = match core::mem::replace(&mut self.state, RecvState::New) { - RecvState::New => { - self.rx.plain_hdr_decode()?; - (RecvState::OpenExchange, None) - } - RecvState::OpenExchange => match self.transport.exch_mgr.recv(&mut self.rx) { - Ok(Some(exch_ctx)) => { - if self.rx.get_proto_id() == PROTO_ID_SECURE_CHANNEL { - let mut proto_ctx = ProtoCtx::new(exch_ctx, &self.rx, &mut self.tx); - - let mut secure_channel = SecureChannel::new(self.transport.matter); - - let (reply, clone_data) = secure_channel.handle(&mut proto_ctx)?; - - let state = if let Some(clone_data) = clone_data { - RecvState::AddSession(clone_data) - } else { - RecvState::Ack - }; - - if reply { - if proto_ctx.send()? { - ( - state, - Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))), - ) - } else { - (state, None) - } - } else { - (state, None) - } - } else { - let proto_ctx = ProtoCtx::new(exch_ctx, &self.rx, &mut self.tx); - - (RecvState::Ack, Some(Some(RecvAction::Interact(proto_ctx)))) - } + pub async fn initiate(&self, _fabric_id: u64, _node_id: u64) -> Result, Error> { + unimplemented!() + } + + pub fn process_rx<'r>( + &'r self, + construction_notification: &'r Notification, + src_rx: &mut Packet<'_>, + ) -> Result>, Error> { + self.purge()?; + + let mut exchanges = self.exchanges.borrow_mut(); + let (ctx, new) = match self.post_recv(&mut exchanges, src_rx) { + Ok((ctx, new)) => (ctx, new), + Err(e) => match e.code() { + ErrorCode::Duplicate => { + self.send_notification.signal(()); + return Ok(None); } - Ok(None) => (RecvState::Ack, None), - Err(e) => match e.code() { - ErrorCode::Duplicate => (RecvState::Ack, None), - ErrorCode::NoSpace => (RecvState::EvictSession, None), - _ => Err(e)?, - }, + _ => Err(e)?, }, - RecvState::AddSession(clone_data) => { - match self.transport.exch_mgr.add_session(&clone_data) { - Ok(_) => (RecvState::Ack, None), - Err(e) => match e.code() { - ErrorCode::NoSpace => (RecvState::EvictSession2(clone_data), None), - _ => Err(e)?, - }, - } - } - RecvState::EvictSession => { - if self.transport.exch_mgr.evict_session(&mut self.tx)? { - ( - RecvState::OpenExchange, - Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))), - ) - } else { - (RecvState::EvictSession, None) + }; + + src_rx.log("Got packet"); + + if src_rx.proto.is_ack() { + if new { + Err(ErrorCode::Invalid)?; + } else { + let state = &mut ctx.state; + + match state { + ExchangeState::ExchangeRecv { + tx_acknowledged, .. + } => { + *tx_acknowledged = true; + } + ExchangeState::CompleteAcknowledge { notification, .. } => { + unsafe { notification.as_ref() }.unwrap().signal(()); + ctx.state = ExchangeState::Closed; + } + _ => { + // TODO: Error handling + todo!() + } } + + self.notify_changed(); } - RecvState::EvictSession2(clone_data) => { - if self.transport.exch_mgr.evict_session(&mut self.tx)? { - ( - RecvState::AddSession(clone_data), - Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))), - ) - } else { - (RecvState::EvictSession2(clone_data), None) + } + + if new { + let constructor = ExchangeCtr { + exchange: Exchange { + id: ctx.id.clone(), + transport: self, + notification: Notification::new(), + }, + construction_notification, + }; + + self.notify_changed(); + + Ok(Some(constructor)) + } else if src_rx.proto.proto_id == PROTO_ID_SECURE_CHANNEL + && src_rx.proto.proto_opcode == OpCode::MRPStandAloneAck as u8 + { + // Standalone ack, do nothing + Ok(None) + } else { + let state = &mut ctx.state; + + match state { + ExchangeState::ExchangeRecv { + rx, notification, .. + } => { + let rx = unsafe { rx.as_mut() }.unwrap(); + rx.load(src_rx)?; + + unsafe { notification.as_ref() }.unwrap().signal(()); + *state = ExchangeState::Active; } - } - RecvState::Ack => { - if let Some(exch_id) = self.transport.exch_mgr.pending_ack() { - info!("Sending MRP Standalone ACK for exch {}", exch_id); - - ReliableMessage::prepare_ack(exch_id, &mut self.tx); - - if self.transport.exch_mgr.send(exch_id, &mut self.tx)? { - ( - RecvState::Ack, - Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))), - ) - } else { - (RecvState::Ack, None) - } - } else { - (RecvState::Ack, Some(None)) + _ => { + // TODO: Error handling + todo!() } } - }; - self.state = state; - Ok(next) + self.notify_changed(); + + Ok(None) + } } -} -enum NotifyState {} + pub async fn wait_construction( + &self, + construction_notification: &Notification, + src_rx: &Packet<'_>, + exchange_id: &ExchangeId, + ) -> Result<(), Error> { + construction_notification.wait().await; -pub enum NotifyAction<'r, 'p> { - Send(&'r [u8]), - Notify(ProtoCtx<'r, 'p>), -} + let mut exchanges = self.exchanges.borrow_mut(); -pub struct NotifyCompletion<'r, 'a> { - // TODO - _transport: &'r mut Transport<'a>, - _rx: Packet<'r>, - _tx: Packet<'r>, - _state: NotifyState, -} + let ctx = Self::get(&mut exchanges, exchange_id).unwrap(); -impl<'r, 'a> NotifyCompletion<'r, 'a> { - pub fn next_action(&mut self) -> Result>, Error> { - loop { - // Polonius will remove the need for unsafe one day - let this = unsafe { (self as *mut NotifyCompletion).as_mut().unwrap() }; + let state = &mut ctx.state; + + match state { + ExchangeState::Construction { rx, notification } => { + let rx = unsafe { rx.as_mut() }.unwrap(); + rx.load(src_rx)?; - if let Some(action) = this.maybe_next_action()? { - return Ok(action); + unsafe { notification.as_ref() }.unwrap().signal(()); + *state = ExchangeState::Active; } + _ => unreachable!(), } + + Ok(()) } - fn maybe_next_action(&mut self) -> Result>>, Error> { - Ok(Some(None)) // TODO: Future + pub async fn wait_tx(&self) -> Result<(), Error> { + select( + self.send_notification.wait(), + Timer::after(Duration::from_millis(100)), + ) + .await; + + Ok(()) } -} -pub struct Transport<'a> { - matter: &'a Matter<'a>, - exch_mgr: exchange::ExchangeMgr, -} + pub async fn pull_tx(&self, dest_tx: &mut Packet<'_>) -> Result { + self.purge()?; -impl<'a> Transport<'a> { - #[inline(always)] - pub fn new(matter: &'a Matter<'a>) -> Self { - let epoch = matter.epoch; - let rand = matter.rand; + let mut exchanges = self.exchanges.borrow_mut(); - Self { - matter, - exch_mgr: exchange::ExchangeMgr::new(epoch, rand), + let ctx = exchanges.iter_mut().find(|ctx| { + matches!( + &ctx.state, + ExchangeState::Acknowledge { .. } + | ExchangeState::ExchangeSend { .. } + // | ExchangeState::ExchangeRecv { + // tx_acknowledged: false, + // .. + // } + | ExchangeState::Complete { .. } // | ExchangeState::CompleteAcknowledge { .. } + ) || ctx.mrp.is_ack_ready(*self.matter.borrow()) + }); + + if let Some(ctx) = ctx { + self.notify_changed(); + + let state = &mut ctx.state; + + let send = match state { + ExchangeState::Acknowledge { notification } => { + ReliableMessage::prepare_ack(ctx.id.id, dest_tx); + + unsafe { notification.as_ref() }.unwrap().signal(()); + *state = ExchangeState::Active; + + true + } + ExchangeState::ExchangeSend { + tx, + rx, + notification, + } => { + let tx = unsafe { tx.as_ref() }.unwrap(); + dest_tx.load(tx)?; + + *state = ExchangeState::ExchangeRecv { + _tx: tx, + tx_acknowledged: false, + rx: *rx, + notification: *notification, + }; + + true + } + // ExchangeState::ExchangeRecv { .. } => { + // // TODO: Re-send the tx package if due + // false + // } + ExchangeState::Complete { tx, notification } => { + let tx = unsafe { tx.as_ref() }.unwrap(); + dest_tx.load(tx)?; + + *state = ExchangeState::CompleteAcknowledge { + _tx: tx as *const _, + notification: *notification, + }; + + true + } + // ExchangeState::CompleteAcknowledge { .. } => { + // // TODO: Re-send the tx package if due + // false + // } + _ => { + ReliableMessage::prepare_ack(ctx.id.id, dest_tx); + true + } + }; + + if send { + dest_tx.log("Sending packet"); + + self.pre_send(ctx, dest_tx)?; + self.notify_changed(); + + return Ok(true); + } } + + Ok(false) } - pub fn matter(&self) -> &Matter<'a> { - self.matter + fn purge(&self) -> Result<(), Error> { + loop { + let mut exchanges = self.exchanges.borrow_mut(); + + if let Some(index) = exchanges.iter_mut().enumerate().find_map(|(index, ctx)| { + matches!(ctx.state, ExchangeState::Closed).then_some(index) + }) { + exchanges.swap_remove(index); + } else { + break; + } + } + + Ok(()) } - pub fn start(&mut self, dev_comm: CommissioningData, buf: &mut [u8]) -> Result<(), Error> { - info!("Starting Matter transport"); + fn post_recv<'r>( + &self, + exchanges: &'r mut heapless::Vec, + rx: &mut Packet<'_>, + ) -> Result<(&'r mut ExchangeCtx, bool), Error> { + rx.plain_hdr_decode()?; + + // Get the session - if self.matter().start_comissioning(dev_comm, buf)? { - info!("Comissioning started"); + let mut session_mgr = self.session_mgr.borrow_mut(); + + let sess_index = session_mgr.post_recv(rx)?; + let session = session_mgr.mut_by_index(sess_index).unwrap(); + + // Decrypt the message + session.recv(self.matter.epoch, rx)?; + + // Get the exchange + // TODO: Handle out of space + let (exch, new) = Self::register( + exchanges, + ExchangeId::load(rx), + Role::complementary(rx.proto.is_initiator()), + // We create a new exchange, only if the peer is the initiator + rx.proto.is_initiator(), + )?; + + // Message Reliability Protocol + exch.mrp.recv(rx, self.matter.epoch)?; + + Ok((exch, new)) + } + + fn pre_send(&self, ctx: &mut ExchangeCtx, tx: &mut Packet) -> Result<(), Error> { + let mut session_mgr = self.session_mgr.borrow_mut(); + let sess_index = session_mgr + .get( + ctx.id.session_id.id, + ctx.id.session_id.peer_addr, + ctx.id.session_id.peer_nodeid, + ctx.id.session_id.is_encrypted, + ) + .ok_or(ErrorCode::NoSession)?; + + let session = session_mgr.mut_by_index(sess_index).unwrap(); + + tx.proto.exch_id = ctx.id.id; + if ctx.role == Role::Initiator { + tx.proto.set_initiator(); } - Ok(()) + session.pre_send(tx)?; + ctx.mrp.pre_send(tx)?; + session_mgr.send(sess_index, tx) + } + + fn register( + exchanges: &mut heapless::Vec, + id: ExchangeId, + role: Role, + create_new: bool, + ) -> Result<(&mut ExchangeCtx, bool), Error> { + let exchange_index = exchanges + .iter_mut() + .enumerate() + .find_map(|(index, exchange)| (exchange.id == id).then_some(index)); + + if let Some(exchange_index) = exchange_index { + let exchange = &mut exchanges[exchange_index]; + if exchange.role == role { + Ok((exchange, false)) + } else { + Err(ErrorCode::NoExchange.into()) + } + } else if create_new { + info!("Creating new exchange: {:?}", id); + + let exchange = ExchangeCtx { + id, + role, + mrp: ReliableMessage::new(), + state: ExchangeState::Active, + }; + + exchanges.push(exchange).map_err(|_| ErrorCode::NoSpace)?; + + Ok((exchanges.iter_mut().next_back().unwrap(), true)) + } else { + Err(ErrorCode::NoExchange.into()) + } } - pub fn recv<'r>( - &'r mut self, - addr: Address, - rx_buf: &'r mut [u8], - tx_buf: &'r mut [u8], - ) -> RecvCompletion<'r, 'a> { - let mut rx = Packet::new_rx(rx_buf); - let tx = Packet::new_tx(tx_buf); - - rx.peer = addr; - - RecvCompletion { - transport: self, - rx, - tx, - state: RecvState::New, + pub(crate) fn get<'r>( + exchanges: &'r mut heapless::Vec, + id: &ExchangeId, + ) -> Option<&'r mut ExchangeCtx> { + exchanges.iter_mut().find(|exchange| exchange.id == *id) + } + + pub fn notify_changed(&self) { + if self.matter().is_changed() { + self.persist_notification.signal(()); } } - pub fn notify(&mut self, _tx: &mut Packet) -> Result { - Ok(false) + pub async fn wait_changed(&self) { + self.persist_notification.wait().await } } diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index 5dbb1bba..fbe3d7aa 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -1,625 +1,320 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use core::fmt; -use core::time::Duration; -use log::{error, info, trace}; -use owo_colors::OwoColorize; - -use crate::error::{Error, ErrorCode}; -use crate::interaction_model::core::{ResumeReadReq, ResumeSubscribeReq}; -use crate::secure_channel; -use crate::secure_channel::case::CaseSession; -use crate::utils::epoch::Epoch; -use crate::utils::rand::Rand; - -use heapless::LinearMap; - -use super::session::CloneData; -use super::{mrp::ReliableMessage, packet::Packet, session::SessionHandle, session::SessionMgr}; - -pub struct ExchangeCtx<'a> { - pub exch: &'a mut Exchange, - pub sess: SessionHandle<'a>, - pub epoch: Epoch, -} +use embassy_sync::blocking_mutex::raw::NoopRawMutex; -impl<'a> ExchangeCtx<'a> { - pub fn send(&mut self, tx: &mut Packet) -> Result { - self.exch.send(tx, &mut self.sess) - } -} +use crate::{ + acl::Accessor, + error::{Error, ErrorCode}, + Matter, +}; + +use super::{ + core::Transport, + mrp::ReliableMessage, + network::Address, + packet::Packet, + session::{Session, SessionMgr}, +}; + +pub const MAX_EXCHANGES: usize = 8; + +pub type Notification = embassy_sync::signal::Signal; #[derive(Debug, PartialEq, Eq, Copy, Clone, Default)] -pub enum Role { +pub(crate) enum Role { #[default] Initiator = 0, Responder = 1, } -#[derive(Debug, PartialEq, Default)] -enum State { - /// The exchange is open and active - #[default] - Open, - /// The exchange is closed, but keys are active since retransmissions/acks may be pending - Close, - /// The exchange is terminated, keys are destroyed, no communication can happen - Terminate, +impl Role { + pub fn complementary(is_initiator: bool) -> Self { + if is_initiator { + Self::Responder + } else { + Self::Initiator + } + } } -// Instead of just doing an Option<>, we create some special handling -// where the commonly used higher layer data store does't have to do a Box -#[derive(Default)] -pub enum DataOption { - CaseSession(CaseSession), - Time(Duration), - SuspendedReadReq(ResumeReadReq), - SubscriptionId(u32), - SuspendedSubscibeReq(ResumeSubscribeReq), - #[default] - None, +#[derive(Debug)] +pub(crate) struct ExchangeCtx { + pub(crate) id: ExchangeId, + pub(crate) role: Role, + pub(crate) mrp: ReliableMessage, + pub(crate) state: ExchangeState, } -#[derive(Default)] -pub struct Exchange { - id: u16, - sess_idx: usize, - role: Role, - state: State, - mrp: ReliableMessage, - // Currently I see this primarily used in PASE and CASE. If that is the limited use - // of this, we might move this into a separate data structure, so as not to burden - // all 'exchanges'. - data: DataOption, +#[derive(Debug, Clone)] +pub(crate) enum ExchangeState { + Construction { + rx: *mut Packet<'static>, + notification: *const Notification, + }, + Active, + Acknowledge { + notification: *const Notification, + }, + ExchangeSend { + tx: *const Packet<'static>, + rx: *mut Packet<'static>, + notification: *const Notification, + }, + ExchangeRecv { + _tx: *const Packet<'static>, + tx_acknowledged: bool, + rx: *mut Packet<'static>, + notification: *const Notification, + }, + Complete { + tx: *const Packet<'static>, + notification: *const Notification, + }, + CompleteAcknowledge { + _tx: *const Packet<'static>, + notification: *const Notification, + }, + Closed, } -impl Exchange { - pub fn new(id: u16, sess_idx: usize, role: Role) -> Exchange { - Exchange { - id, - sess_idx, - role, - state: State::Open, - mrp: ReliableMessage::new(), - ..Default::default() - } - } +pub struct ExchangeCtr<'a> { + pub(crate) exchange: Exchange<'a>, + pub(crate) construction_notification: &'a Notification, +} - pub fn terminate(&mut self) { - self.data = DataOption::None; - self.state = State::Terminate; +impl<'a> ExchangeCtr<'a> { + pub const fn id(&self) -> &ExchangeId { + self.exchange.id() } - pub fn close(&mut self) { - self.data = DataOption::None; - self.state = State::Close; - } + pub async fn get(mut self, rx: &mut Packet<'_>) -> Result, Error> { + let construction_notification = self.construction_notification; - pub fn is_state_open(&self) -> bool { - self.state == State::Open - } + self.exchange.with_ctx_mut(move |exchange, ctx| { + if !matches!(ctx.state, ExchangeState::Active) { + Err(ErrorCode::NoExchange)?; + } - pub fn is_purgeable(&self) -> bool { - // No Users, No pending ACKs/Retrans - self.state == State::Terminate || (self.state == State::Close && self.mrp.is_empty()) - } + let rx: &'static mut Packet<'static> = unsafe { core::mem::transmute(rx) }; + let notification: &'static Notification = + unsafe { core::mem::transmute(&exchange.notification) }; - pub fn get_id(&self) -> u16 { - self.id - } + ctx.state = ExchangeState::Construction { rx, notification }; - pub fn get_role(&self) -> Role { - self.role - } + construction_notification.signal(()); - pub fn clear_data(&mut self) { - self.data = DataOption::None; - } + Ok(()) + })?; - pub fn set_case_session(&mut self, session: CaseSession) { - self.data = DataOption::CaseSession(session); - } + self.exchange.notification.wait().await; - pub fn get_case_session(&mut self) -> Option<&mut CaseSession> { - if let DataOption::CaseSession(session) = &mut self.data { - Some(session) - } else { - None - } - } - - pub fn take_case_session(&mut self) -> Option { - let old = core::mem::replace(&mut self.data, DataOption::None); - if let DataOption::CaseSession(session) = old { - Some(session) - } else { - self.data = old; - None - } + Ok(self.exchange) } +} - pub fn set_suspended_read_req(&mut self, req: ResumeReadReq) { - self.data = DataOption::SuspendedReadReq(req); - } +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct ExchangeId { + pub id: u16, + pub session_id: SessionId, +} - pub fn take_suspended_read_req(&mut self) -> Option { - let old = core::mem::replace(&mut self.data, DataOption::None); - if let DataOption::SuspendedReadReq(req) = old { - Some(req) - } else { - self.data = old; - None +impl ExchangeId { + pub fn load(rx: &Packet) -> Self { + Self { + id: rx.proto.exch_id, + session_id: SessionId::load(rx), } } +} +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct SessionId { + pub id: u16, + pub peer_addr: Address, + pub peer_nodeid: Option, + pub is_encrypted: bool, +} - pub fn set_subscription_id(&mut self, id: u32) { - self.data = DataOption::SubscriptionId(id); - } - - pub fn take_subscription_id(&mut self) -> Option { - let old = core::mem::replace(&mut self.data, DataOption::None); - if let DataOption::SubscriptionId(id) = old { - Some(id) - } else { - self.data = old; - None +impl SessionId { + pub fn load(rx: &Packet) -> Self { + Self { + id: rx.plain.sess_id, + peer_addr: rx.peer, + peer_nodeid: rx.plain.get_src_u64(), + is_encrypted: rx.plain.is_encrypted(), } } +} +pub struct Exchange<'a> { + pub(crate) id: ExchangeId, + pub(crate) transport: &'a Transport<'a>, + pub(crate) notification: Notification, +} - pub fn set_suspended_subscribe_req(&mut self, req: ResumeSubscribeReq) { - self.data = DataOption::SuspendedSubscibeReq(req); +impl<'a> Exchange<'a> { + pub const fn id(&self) -> &ExchangeId { + &self.id } - pub fn take_suspended_subscribe_req(&mut self) -> Option { - let old = core::mem::replace(&mut self.data, DataOption::None); - if let DataOption::SuspendedSubscibeReq(req) = old { - Some(req) - } else { - self.data = old; - None - } + pub fn matter(&self) -> &Matter<'a> { + self.transport.matter() } - pub fn set_data_time(&mut self, expiry_ts: Option) { - if let Some(t) = expiry_ts { - self.data = DataOption::Time(t); - } + pub fn transport(&self) -> &Transport<'a> { + self.transport } - pub fn get_data_time(&self) -> Option { - match self.data { - DataOption::Time(t) => Some(t), - _ => None, - } + pub fn accessor(&self) -> Result, Error> { + self.with_session(|sess| { + Ok(Accessor::for_session( + sess, + &self.transport.matter().acl_mgr, + )) + }) } - pub(crate) fn send( - &mut self, - tx: &mut Packet, - session: &mut SessionHandle, - ) -> Result { - if self.state == State::Terminate { - info!("Skipping tx for terminated exchange {}", self.id); - return Ok(false); - } - - trace!("payload: {:x?}", tx.as_slice()); - info!( - "{} with proto id: {} opcode: {}, tlv:\n", - "Sending".blue(), - tx.get_proto_id(), - tx.get_proto_raw_opcode(), - ); - - //print_tlv_list(tx.as_slice()); + pub fn with_session_mut(&self, f: F) -> Result + where + F: FnOnce(&mut Session) -> Result, + { + self.with_ctx(|_self, ctx| { + let mut session_mgr = _self.transport.session_mgr.borrow_mut(); - tx.proto.exch_id = self.id; - if self.role == Role::Initiator { - tx.proto.set_initiator(); - } + let sess_index = session_mgr + .get( + ctx.id.session_id.id, + ctx.id.session_id.peer_addr, + ctx.id.session_id.peer_nodeid, + ctx.id.session_id.is_encrypted, + ) + .ok_or(ErrorCode::NoSession)?; - session.pre_send(tx)?; - self.mrp.pre_send(tx)?; - session.send(tx)?; - - Ok(true) + f(session_mgr.mut_by_index(sess_index).unwrap()) + }) } -} -impl fmt::Display for Exchange { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "exch_id: {:?}, sess_index: {}, role: {:?}, mrp: {:?}, state: {:?}", - self.id, self.sess_idx, self.role, self.mrp, self.state - ) + pub fn with_session(&self, f: F) -> Result + where + F: FnOnce(&Session) -> Result, + { + self.with_session_mut(|sess| f(sess)) } -} -pub fn get_role(is_initiator: bool) -> Role { - if is_initiator { - Role::Initiator - } else { - Role::Responder + pub fn with_session_mgr_mut(&self, f: F) -> Result + where + F: FnOnce(&mut SessionMgr) -> Result, + { + let mut session_mgr = self.transport.session_mgr.borrow_mut(); + + f(&mut session_mgr) } -} -pub fn get_complementary_role(is_initiator: bool) -> Role { - if is_initiator { - Role::Responder - } else { - Role::Initiator + pub async fn initiate(&mut self, fabric_id: u64, node_id: u64) -> Result, Error> { + self.transport.initiate(fabric_id, node_id).await } -} -const MAX_EXCHANGES: usize = 8; + pub async fn acknowledge(&mut self) -> Result<(), Error> { + let wait = self.with_ctx_mut(|_self, ctx| { + if !matches!(ctx.state, ExchangeState::Active) { + Err(ErrorCode::NoExchange)?; + } -pub struct ExchangeMgr { - // keys: exch-id - exchanges: LinearMap, - sess_mgr: SessionMgr, - epoch: Epoch, -} + if ctx.mrp.is_empty() { + Ok(false) + } else { + ctx.state = ExchangeState::Acknowledge { + notification: &_self.notification as *const _, + }; + _self.transport.send_notification.signal(()); -pub const MAX_MRP_ENTRIES: usize = 4; + Ok(true) + } + })?; -impl ExchangeMgr { - #[inline(always)] - pub fn new(epoch: Epoch, rand: Rand) -> Self { - Self { - sess_mgr: SessionMgr::new(epoch, rand), - exchanges: LinearMap::new(), - epoch, + if wait { + self.notification.wait().await; } - } - pub fn get_sess_mgr(&mut self) -> &mut SessionMgr { - &mut self.sess_mgr + Ok(()) } - pub fn _get_with_id( - exchanges: &mut LinearMap, - exch_id: u16, - ) -> Option<&mut Exchange> { - exchanges.get_mut(&exch_id) - } - - pub fn get_with_id(&mut self, exch_id: u16) -> Option<&mut Exchange> { - ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id) - } - - fn _get( - exchanges: &mut LinearMap, - sess_idx: usize, - id: u16, - role: Role, - create_new: bool, - ) -> Result<&mut Exchange, Error> { - // I don't prefer that we scan the list twice here (once for contains_key and other) - if !exchanges.contains_key(&(id)) { - if create_new { - // If an exchange doesn't exist, create a new one - info!("Creating new exchange"); - let e = Exchange::new(id, sess_idx, role); - if exchanges.insert(id, e).is_err() { - Err(ErrorCode::NoSpace)?; - } - } else { - Err(ErrorCode::NoSpace)?; - } - } + pub async fn exchange(&mut self, tx: &Packet<'_>, rx: &mut Packet<'_>) -> Result<(), Error> { + let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) }; + let rx: &mut Packet<'static> = unsafe { core::mem::transmute(rx) }; - // At this point, we would either have inserted the record if 'create_new' was set - // or it existed already - if let Some(result) = exchanges.get_mut(&id) { - if result.get_role() == role && sess_idx == result.sess_idx { - Ok(result) - } else { - Err(ErrorCode::NoExchange.into()) + self.with_ctx_mut(|_self, ctx| { + if !matches!(ctx.state, ExchangeState::Active) { + Err(ErrorCode::NoExchange)?; } - } else { - error!("This should never happen"); - Err(ErrorCode::NoSpace.into()) - } - } - /// The Exchange Mgr receive is like a big processing function - pub fn recv(&mut self, rx: &mut Packet) -> Result, Error> { - // Get the session - let index = self.sess_mgr.post_recv(rx)?; - let mut session = self.sess_mgr.get_session_handle(index); - - // Decrypt the message - session.recv(self.epoch, rx)?; - - // Get the exchange - let exch = ExchangeMgr::_get( - &mut self.exchanges, - index, - rx.proto.exch_id, - get_complementary_role(rx.proto.is_initiator()), - // We create a new exchange, only if the peer is the initiator - rx.proto.is_initiator(), - )?; - - // Message Reliability Protocol - exch.mrp.recv(rx, self.epoch)?; - - if exch.is_state_open() { - Ok(Some(ExchangeCtx { - exch, - sess: session, - epoch: self.epoch, - })) - } else { - // Instead of an error, we send None here, because it is likely that - // we just processed an acknowledgement that cleared the exchange - Ok(None) - } - } + ctx.state = ExchangeState::ExchangeSend { + tx: tx as *const _, + rx: rx as *mut _, + notification: &_self.notification as *const _, + }; + _self.transport.send_notification.signal(()); - pub fn send(&mut self, exch_id: u16, tx: &mut Packet) -> Result { - let exchange = - ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id).ok_or(ErrorCode::NoExchange)?; - let mut session = self.sess_mgr.get_session_handle(exchange.sess_idx); - exchange.send(tx, &mut session) - } + Ok(()) + })?; - pub fn purge(&mut self) { - let mut to_purge: LinearMap = LinearMap::new(); + self.notification.wait().await; - for (exch_id, exchange) in self.exchanges.iter() { - if exchange.is_purgeable() { - let _ = to_purge.insert(*exch_id, ()); - } - } - for (exch_id, _) in to_purge.iter() { - self.exchanges.remove(exch_id); - } + Ok(()) } - pub fn pending_ack(&mut self) -> Option { - self.exchanges - .iter() - .find(|(_, exchange)| exchange.mrp.is_ack_ready(self.epoch)) - .map(|(exch_id, _)| *exch_id) + pub async fn complete(mut self, tx: &Packet<'_>) -> Result<(), Error> { + self.send_complete(tx).await } - pub fn evict_session(&mut self, tx: &mut Packet) -> Result { - if let Some(index) = self.sess_mgr.get_session_for_eviction() { - info!("Sessions full, vacating session with index: {}", index); - // If we enter here, we have an LRU session that needs to be reclaimed - // As per the spec, we need to send a CLOSE here - - let mut session = self.sess_mgr.get_session_handle(index); - secure_channel::common::create_sc_status_report( - tx, - secure_channel::common::SCStatusCodes::CloseSession, - None, - )?; - - if let Some((_, exchange)) = - self.exchanges.iter_mut().find(|(_, e)| e.sess_idx == index) - { - // Send Close_session on this exchange, and then close the session - // Should this be done for all exchanges? - error!("Sending Close Session"); - exchange.send(tx, &mut session)?; - // TODO: This wouldn't actually send it out, because 'transport' isn't owned yet. - } + pub async fn send_complete(&mut self, tx: &Packet<'_>) -> Result<(), Error> { + let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) }; - let remove_exchanges: heapless::Vec = self - .exchanges - .iter() - .filter_map(|(eid, e)| { - if e.sess_idx == index { - Some(*eid) - } else { - None - } - }) - .collect(); - info!( - "Terminating the following exchanges: {:?}", - remove_exchanges - ); - for exch_id in remove_exchanges { - // Remove from exchange list - self.exchanges.remove(&exch_id); + self.with_ctx_mut(|_self, ctx| { + if !matches!(ctx.state, ExchangeState::Active) { + Err(ErrorCode::NoExchange)?; } - self.sess_mgr.remove(index); + ctx.state = ExchangeState::Complete { + tx: tx as *const _, + notification: &_self.notification as *const _, + }; + _self.transport.send_notification.signal(()); - Ok(true) - } else { - Ok(false) - } - } + Ok(()) + })?; - pub fn add_session(&mut self, clone_data: &CloneData) -> Result { - let sess_idx = self.sess_mgr.clone_session(clone_data)?; + self.notification.wait().await; - Ok(self.sess_mgr.get_session_handle(sess_idx)) + Ok(()) } -} -impl fmt::Display for ExchangeMgr { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "{{ Session Mgr: {},", self.sess_mgr)?; - writeln!(f, " Exchanges: [")?; - for s in &self.exchanges { - writeln!(f, "{{ {}, }},", s.1)?; - } - writeln!(f, " ]")?; - write!(f, "}}") - } -} + fn with_ctx(&self, f: F) -> Result + where + F: FnOnce(&Self, &ExchangeCtx) -> Result, + { + let mut exchanges = self.transport.exchanges.borrow_mut(); -#[cfg(test)] -#[allow(clippy::bool_assert_comparison)] -mod tests { - use crate::{ - error::ErrorCode, - transport::{ - network::Address, - session::{CloneData, SessionMode}, - }, - utils::{epoch::dummy_epoch, rand::dummy_rand}, - }; - - use super::{ExchangeMgr, Role}; - - #[test] - fn test_purge() { - let mut mgr = ExchangeMgr::new(dummy_epoch, dummy_rand); - let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, true).unwrap(); - let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 3, Role::Responder, true).unwrap(); - - mgr.purge(); - assert_eq!( - ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, false).is_ok(), - true - ); - assert_eq!( - ExchangeMgr::_get(&mut mgr.exchanges, 1, 3, Role::Responder, false).is_ok(), - true - ); - - // Close e1 - let e1 = ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, false).unwrap(); - e1.close(); - mgr.purge(); - assert_eq!( - ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, false).is_ok(), - false - ); - assert_eq!( - ExchangeMgr::_get(&mut mgr.exchanges, 1, 3, Role::Responder, false).is_ok(), - true - ); - } + let exchange = Transport::get(&mut exchanges, &self.id).ok_or(ErrorCode::NoExchange)?; // TODO - fn get_clone_data(peer_sess_id: u16, local_sess_id: u16) -> CloneData { - CloneData::new( - 12341234, - 43211234, - peer_sess_id, - local_sess_id, - Address::default(), - SessionMode::Pase, - ) + f(self, exchange) } - fn fill_sessions(mgr: &mut ExchangeMgr, count: usize) { - let mut local_sess_id = 1; - let mut peer_sess_id = 100; - for _ in 1..count { - let clone_data = get_clone_data(peer_sess_id, local_sess_id); - match mgr.add_session(&clone_data) { - Ok(s) => assert_eq!(peer_sess_id, s.get_peer_sess_id()), - Err(e) => { - if e.code() == ErrorCode::NoSpace { - break; - } else { - panic!("Could not create sessions"); - } - } - } - local_sess_id += 1; - peer_sess_id += 1; - } + fn with_ctx_mut(&mut self, f: F) -> Result + where + F: FnOnce(&mut Self, &mut ExchangeCtx) -> Result, + { + let mut exchanges = self.transport.exchanges.borrow_mut(); + + let exchange = Transport::get(&mut exchanges, &self.id).ok_or(ErrorCode::NoExchange)?; // TODO + + f(self, exchange) } +} - #[cfg(feature = "std")] - #[test] - /// We purposefuly overflow the sessions - /// and when the overflow happens, we confirm that - /// - The sessions are evicted in LRU - /// - The exchanges associated with those sessions are evicted too - fn test_sess_evict() { - use crate::transport::packet::{Packet, MAX_TX_BUF_SIZE}; - use crate::transport::session::MAX_SESSIONS; - - let mut mgr = ExchangeMgr::new(crate::utils::epoch::sys_epoch, dummy_rand); - - fill_sessions(&mut mgr, MAX_SESSIONS + 1); - // Sessions are now full from local session id 1 to 16 - - // Create exchanges for sessions 2 (i.e. session index 1) and 3 (session index 2) - // Exchange IDs are 20 and 30 respectively - let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 20, Role::Responder, true).unwrap(); - let _ = ExchangeMgr::_get(&mut mgr.exchanges, 2, 30, Role::Responder, true).unwrap(); - - // Confirm that session ids 1 to MAX_SESSIONS exists - for i in 1..(MAX_SESSIONS + 1) { - assert_eq!(mgr.sess_mgr.get_with_id(i as u16).is_none(), false); - } - // Confirm that the exchanges are around - assert_eq!(mgr.get_with_id(20).is_none(), false); - assert_eq!(mgr.get_with_id(30).is_none(), false); - let mut old_local_sess_id = 1; - let mut new_local_sess_id = 100; - let mut new_peer_sess_id = 200; - - for i in 1..(MAX_SESSIONS + 1) { - // Now purposefully overflow the sessions by adding another session - let result = mgr.add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id)); - assert!(matches!( - result.map_err(|e| e.code()), - Err(ErrorCode::NoSpace) - )); - - let mut buf = [0; MAX_TX_BUF_SIZE]; - let tx = &mut Packet::new_tx(&mut buf); - let evicted = mgr.evict_session(tx).unwrap(); - assert!(evicted); - - let session = mgr - .add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id)) - .unwrap(); - assert_eq!(session.get_peer_sess_id(), new_peer_sess_id); - - // This should have evicted session with local sess_id - assert_eq!(mgr.sess_mgr.get_with_id(old_local_sess_id).is_none(), true); - - new_local_sess_id += 1; - new_peer_sess_id += 1; - old_local_sess_id += 1; - - match i { - 1 => { - // Both exchanges should exist - assert_eq!(mgr.get_with_id(20).is_none(), false); - assert_eq!(mgr.get_with_id(30).is_none(), false); - } - 2 => { - // Exchange 20 would have been evicted - assert_eq!(mgr.get_with_id(20).is_none(), true); - assert_eq!(mgr.get_with_id(30).is_none(), false); - } - 3 => { - // Exchange 20 and 30 would have been evicted - assert_eq!(mgr.get_with_id(20).is_none(), true); - assert_eq!(mgr.get_with_id(30).is_none(), true); - } - _ => {} - } - } - // println!("Session mgr {}", mgr.sess_mgr); +impl<'a> Drop for Exchange<'a> { + fn drop(&mut self) { + let _ = self.with_ctx_mut(|_self, ctx| { + ctx.state = ExchangeState::Closed; + _self.transport.send_notification.signal(()); + + Ok(()) + }); } } diff --git a/matter/src/transport/mod.rs b/matter/src/transport/mod.rs index a219f165..6c5601e7 100644 --- a/matter/src/transport/mod.rs +++ b/matter/src/transport/mod.rs @@ -23,7 +23,7 @@ pub mod network; pub mod packet; pub mod pipe; pub mod plain_hdr; -pub mod proto_ctx; pub mod proto_hdr; +pub mod runner; pub mod session; pub mod udp; diff --git a/matter/src/transport/proto_ctx.rs b/matter/src/transport/proto_ctx.rs deleted file mode 100644 index b7374eca..00000000 --- a/matter/src/transport/proto_ctx.rs +++ /dev/null @@ -1,41 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use crate::error::Error; - -use super::exchange::ExchangeCtx; -use super::packet::Packet; - -/// This is the context in which a receive packet is being processed -pub struct ProtoCtx<'a, 'b> { - /// This is the exchange context, that includes the exchange and the session - pub exch_ctx: ExchangeCtx<'a>, - /// This is the received buffer for this transaction - pub rx: &'a Packet<'b>, - /// This is the transmit buffer for this transaction - pub tx: &'a mut Packet<'b>, -} - -impl<'a, 'b> ProtoCtx<'a, 'b> { - pub fn new(exch_ctx: ExchangeCtx<'a>, rx: &'a Packet<'b>, tx: &'a mut Packet<'b>) -> Self { - Self { exch_ctx, rx, tx } - } - - pub fn send(&mut self) -> Result { - self.exch_ctx.exch.send(self.tx, &mut self.exch_ctx.sess) - } -} diff --git a/matter/src/transport/runner.rs b/matter/src/transport/runner.rs new file mode 100644 index 00000000..f94e819e --- /dev/null +++ b/matter/src/transport/runner.rs @@ -0,0 +1,392 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use core::{mem::MaybeUninit, pin::pin}; + +use crate::{ + alloc, + data_model::{core::DataModel, objects::DataModelHandler}, + interaction_model::core::PROTO_ID_INTERACTION_MODEL, + transport::network::{Address, IpAddr, Ipv6Addr, SocketAddr}, + CommissioningData, Matter, +}; +use embassy_futures::select::{select, select3, select_slice, Either}; +use embassy_sync::{blocking_mutex::raw::NoopRawMutex, channel::Channel}; +use log::{error, info, warn}; + +use crate::{ + error::Error, + secure_channel::{common::PROTO_ID_SECURE_CHANNEL, core::SecureChannel}, + transport::packet::{Packet, MAX_RX_BUF_SIZE}, + utils::select::EitherUnwrap, +}; + +use super::{ + core::Transport, + exchange::{ExchangeCtr, Notification, MAX_EXCHANGES}, + packet::{MAX_RX_STATUS_BUF_SIZE, MAX_TX_BUF_SIZE}, + pipe::{Chunk, Pipe}, + udp::UdpListener, +}; + +pub type TxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; +pub type RxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>; +type SxBuf = MaybeUninit<[u8; MAX_RX_STATUS_BUF_SIZE]>; + +struct PacketPools { + tx: [TxBuf; MAX_EXCHANGES], + rx: [RxBuf; MAX_EXCHANGES], + sx: [SxBuf; MAX_EXCHANGES], +} + +impl PacketPools { + const TX_ELEM: TxBuf = MaybeUninit::uninit(); + const RX_ELEM: RxBuf = MaybeUninit::uninit(); + const SX_ELEM: SxBuf = MaybeUninit::uninit(); + + const TX_INIT: [TxBuf; MAX_EXCHANGES] = [Self::TX_ELEM; MAX_EXCHANGES]; + const RX_INIT: [RxBuf; MAX_EXCHANGES] = [Self::RX_ELEM; MAX_EXCHANGES]; + const SX_INIT: [SxBuf; MAX_EXCHANGES] = [Self::SX_ELEM; MAX_EXCHANGES]; + + #[inline(always)] + pub const fn new() -> Self { + Self { + tx: Self::TX_INIT, + rx: Self::RX_INIT, + sx: Self::SX_INIT, + } + } +} + +/// This struct implements an executor-agnostic option to run the Matter transport stack end-to-end. +/// +/// Since it is not possible to use executor tasks spawning in an executor-agnostic way (yet), +/// the async loops are arranged as one giant future. Therefore, the cost is a slightly slower execution +/// due to the generated future being relatively big and deeply nested. +/// +/// Users are free to implement their own async execution loop, by utilizing the `Transport` +/// struct directly with their async executor of choice. +pub struct TransportRunner<'a> { + transport: Transport<'a>, + pools: PacketPools, +} + +impl<'a> TransportRunner<'a> { + #[inline(always)] + pub fn new(matter: &'a Matter<'a>) -> Self { + Self::wrap(Transport::new(matter)) + } + + #[inline(always)] + pub const fn wrap(transport: Transport<'a>) -> Self { + Self { + transport, + pools: PacketPools::new(), + } + } + + pub fn transport(&self) -> &Transport { + &self.transport + } + + pub async fn run_udp( + &mut self, + tx_buf: &mut TxBuf, + rx_buf: &mut RxBuf, + dev_comm: CommissioningData, + handler: &H, + ) -> Result<(), Error> + where + H: DataModelHandler, + { + let udp = UdpListener::new(SocketAddr::new( + IpAddr::V6(Ipv6Addr::UNSPECIFIED), + self.transport.matter().port, + )) + .await?; + + let tx_pipe = Pipe::new(unsafe { tx_buf.assume_init_mut() }); + let rx_pipe = Pipe::new(unsafe { rx_buf.assume_init_mut() }); + + let tx_pipe = &tx_pipe; + let rx_pipe = &rx_pipe; + let udp = &udp; + + let mut tx = pin!(async move { + loop { + { + let mut data = tx_pipe.data.lock().await; + + if let Some(chunk) = data.chunk { + udp.send(chunk.addr.unwrap_udp(), &data.buf[chunk.start..chunk.end]) + .await?; + data.chunk = None; + tx_pipe.data_consumed_notification.signal(()); + } + } + + tx_pipe.data_supplied_notification.wait().await; + } + }); + + let mut rx = pin!(async move { + loop { + { + let mut data = rx_pipe.data.lock().await; + + if data.chunk.is_none() { + let (len, addr) = udp.recv(data.buf).await?; + + data.chunk = Some(Chunk { + start: 0, + end: len, + addr: Address::Udp(addr), + }); + rx_pipe.data_supplied_notification.signal(()); + } + } + + rx_pipe.data_consumed_notification.wait().await; + } + }); + + let mut run = pin!(async move { self.run(tx_pipe, rx_pipe, dev_comm, handler).await }); + + select3(&mut tx, &mut rx, &mut run).await.unwrap() + } + + pub async fn run( + &mut self, + tx_pipe: &Pipe<'_>, + rx_pipe: &Pipe<'_>, + dev_comm: CommissioningData, + handler: &H, + ) -> Result<(), Error> + where + H: DataModelHandler, + { + info!("Running Matter transport"); + + let buf = unsafe { self.pools.rx[0].assume_init_mut() }; + + if self.transport.matter().start_comissioning(dev_comm, buf)? { + info!("Comissioning started"); + } + + let construction_notification = Notification::new(); + + let mut rx = pin!(Self::handle_rx( + &self.transport, + &mut self.pools, + rx_pipe, + &construction_notification, + handler + )); + let mut tx = pin!(Self::handle_tx(&self.transport, tx_pipe)); + + select(&mut rx, &mut tx).await.unwrap() + } + + async fn handle_rx( + transport: &Transport<'_>, + pools: &mut PacketPools, + rx_pipe: &Pipe<'_>, + construction_notification: &Notification, + handler: &H, + ) -> Result<(), Error> + where + H: DataModelHandler, + { + info!("Creating queue for {} exchanges", 1); + + let channel = Channel::::new(); + + info!("Creating {} handlers", MAX_EXCHANGES); + let mut handlers = heapless::Vec::<_, MAX_EXCHANGES>::new(); + + info!("Handlers size: {}", core::mem::size_of_val(&handlers)); + + let pools = &mut *pools as *mut _; + + for index in 0..MAX_EXCHANGES { + let channel = &channel; + let handler_id = index; + + handlers + .push(async move { + loop { + let exchange_ctr: ExchangeCtr<'_> = channel.recv().await; + + info!( + "Handler {}: Got exchange {:?}", + handler_id, + exchange_ctr.id() + ); + + let result = Self::handle_exchange( + transport, + pools, + handler_id, + exchange_ctr, + handler, + ) + .await; + + if let Err(err) = result { + warn!( + "Handler {}: Exchange closed because of error: {:?}", + handler_id, err + ); + } else { + info!("Handler {}: Exchange completed", handler_id); + } + } + }) + .map_err(|_| ()) + .unwrap(); + } + + let mut rx = pin!(async { + loop { + info!("Transport: waiting for incoming packets"); + + { + let mut data = rx_pipe.data.lock().await; + + if let Some(chunk) = data.chunk { + let mut rx = alloc!(Packet::new_rx(&mut data.buf[chunk.start..chunk.end])); + rx.peer = chunk.addr; + + if let Some(exchange_ctr) = + transport.process_rx(construction_notification, &mut rx)? + { + let exchange_id = exchange_ctr.id().clone(); + + info!("Transport: got new exchange: {:?}", exchange_id); + + channel.send(exchange_ctr).await; + info!("Transport: exchange sent"); + + transport + .wait_construction(construction_notification, &rx, &exchange_id) + .await?; + + info!("Transport: exchange started"); + } + + data.chunk = None; + rx_pipe.data_consumed_notification.signal(()); + } + } + + rx_pipe.data_supplied_notification.wait().await + } + + #[allow(unreachable_code)] + Ok::<_, Error>(()) + }); + + let result = select(&mut rx, select_slice(&mut handlers)).await; + + if let Either::First(result) = result { + if let Err(e) = &result { + error!("Exitting RX loop due to an error: {:?}", e); + } + + result?; + } + + Ok(()) + } + + async fn handle_tx(transport: &Transport<'_>, tx_pipe: &Pipe<'_>) -> Result<(), Error> { + loop { + loop { + { + let mut data = tx_pipe.data.lock().await; + + if data.chunk.is_none() { + let mut tx = alloc!(Packet::new_tx(data.buf)); + + if transport.pull_tx(&mut tx).await? { + data.chunk = Some(Chunk { + start: tx.get_writebuf()?.get_start(), + end: tx.get_writebuf()?.get_tail(), + addr: tx.peer, + }); + tx_pipe.data_supplied_notification.signal(()); + } else { + break; + } + } + } + + tx_pipe.data_consumed_notification.wait().await; + } + + transport.wait_tx().await?; + } + } + + #[cfg_attr(feature = "nightly", allow(clippy::await_holding_refcell_ref))] // Fine because of the async mutex + async fn handle_exchange( + transport: &Transport<'_>, + pools: *mut PacketPools, + handler_id: usize, + exchange_ctr: ExchangeCtr<'_>, + handler: &H, + ) -> Result<(), Error> + where + H: DataModelHandler, + { + let pools = unsafe { pools.as_mut() }.unwrap(); + + let tx_buf = unsafe { pools.tx[handler_id].assume_init_mut() }; + let rx_buf = unsafe { pools.rx[handler_id].assume_init_mut() }; + let rx_status_buf = unsafe { pools.sx[handler_id].assume_init_mut() }; + + let mut rx = alloc!(Packet::new_rx(rx_buf.as_mut())); + let mut tx = alloc!(Packet::new_tx(tx_buf.as_mut())); + + let mut exchange = alloc!(exchange_ctr.get(&mut rx).await?); + + match rx.get_proto_id() { + PROTO_ID_SECURE_CHANNEL => { + let sc = SecureChannel::new(transport.matter()); + + sc.handle(&mut exchange, &mut rx, &mut tx).await?; + + transport.notify_changed(); + } + PROTO_ID_INTERACTION_MODEL => { + let dm = DataModel::new(handler); + + let mut rx_status = alloc!(Packet::new_rx(rx_status_buf)); + + dm.handle(&mut exchange, &mut rx, &mut tx, &mut rx_status) + .await?; + + transport.notify_changed(); + } + other => { + error!("Unknown Proto-ID: {}", other); + } + } + + Ok(()) + } +} diff --git a/matter/tests/common/echo_cluster.rs b/matter/tests/common/echo_cluster.rs index 5e43e18a..dd83f0ac 100644 --- a/matter/tests/common/echo_cluster.rs +++ b/matter/tests/common/echo_cluster.rs @@ -15,10 +15,9 @@ * limitations under the License. */ -use std::{ - convert::TryInto, - sync::{Arc, Mutex, Once}, -}; +use core::cell::Cell; +use core::convert::TryInto; +use std::sync::{Arc, Mutex, Once}; use matter::{ attribute_enum, command_enum, @@ -28,11 +27,9 @@ use matter::{ Quality, ATTRIBUTE_LIST, FEATURE_MAP, }, error::{Error, ErrorCode}, - interaction_model::{ - core::Transaction, - messages::ib::{attr_list_write, ListOperation}, - }, + interaction_model::messages::ib::{attr_list_write, ListOperation}, tlv::{TLVElement, TagType}, + transport::exchange::Exchange, utils::rand::Rand, }; use num_derive::FromPrimitive; @@ -132,10 +129,10 @@ pub const WRITE_LIST_MAX: usize = 5; pub struct EchoCluster { pub data_ver: Dataver, pub multiplier: u8, - pub att1: u16, - pub att2: u16, - pub att_write: u16, - pub att_custom: u32, + pub att1: Cell, + pub att2: Cell, + pub att_write: Cell, + pub att_custom: Cell, } impl EchoCluster { @@ -143,10 +140,10 @@ impl EchoCluster { Self { data_ver: Dataver::new(rand), multiplier, - att1: 0x1234, - att2: 0x5678, - att_write: ATTR_WRITE_DEFAULT_VALUE, - att_custom: ATTR_CUSTOM_VALUE, + att1: Cell::new(0x1234), + att2: Cell::new(0x5678), + att_write: Cell::new(ATTR_WRITE_DEFAULT_VALUE), + att_custom: Cell::new(ATTR_CUSTOM_VALUE), } } @@ -179,14 +176,14 @@ impl EchoCluster { } } - pub fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + pub fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { let data = data.with_dataver(self.data_ver.get())?; match attr.attr_id.try_into()? { - Attributes::Att1(codec) => self.att1 = codec.decode(data)?, - Attributes::Att2(codec) => self.att2 = codec.decode(data)?, - Attributes::AttWrite(codec) => self.att_write = codec.decode(data)?, - Attributes::AttCustom(codec) => self.att_custom = codec.decode(data)?, + Attributes::Att1(codec) => self.att1.set(codec.decode(data)?), + Attributes::Att2(codec) => self.att2.set(codec.decode(data)?), + Attributes::AttWrite(codec) => self.att_write.set(codec.decode(data)?), + Attributes::AttCustom(codec) => self.att_custom.set(codec.decode(data)?), Attributes::AttWriteList(_) => { attr_list_write(attr, data, |op, data| self.write_attr_list(&op, data))? } @@ -198,8 +195,8 @@ impl EchoCluster { } pub fn invoke( - &mut self, - _transaction: &mut Transaction, + &self, + _exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, @@ -222,7 +219,7 @@ impl EchoCluster { } } - fn write_attr_list(&mut self, op: &ListOperation, data: &TLVElement) -> Result<(), Error> { + fn write_attr_list(&self, op: &ListOperation, data: &TLVElement) -> Result<(), Error> { let tc_handle = TestChecker::get().unwrap(); let mut tc = tc_handle.lock().unwrap(); match op { @@ -272,18 +269,18 @@ impl Handler for EchoCluster { EchoCluster::read(self, attr, encoder) } - fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { EchoCluster::write(self, attr, data) } fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { - EchoCluster::invoke(self, transaction, cmd, data, encoder) + EchoCluster::invoke(self, exchange, cmd, data, encoder) } } diff --git a/matter/tests/common/handlers.rs b/matter/tests/common/handlers.rs index 7235b8a3..97de89ae 100644 --- a/matter/tests/common/handlers.rs +++ b/matter/tests/common/handlers.rs @@ -1,8 +1,6 @@ -use core::time; -use std::thread; - use log::{info, warn}; use matter::{ + error::ErrorCode, interaction_model::{ core::{IMStatusCode, OpCode}, messages::{ @@ -14,17 +12,12 @@ use matter::{ }, }, tlv::{self, FromTLV, TLVArray, ToTLV}, - transport::{ - exchange::{self, Exchange}, - session::NocCatIds, - }, - Matter, }; use super::{ attributes::assert_attr_report, commands::{assert_inv_response, ExpectedInvResp}, - im_engine::{ImEngine, ImInput, IM_ENGINE_PEER_ID}, + im_engine::{ImEngine, ImEngineHandler, ImInput, ImOutput}, }; pub enum WriteResponse<'a> { @@ -38,72 +31,71 @@ pub enum TimedInvResponse<'a> { } impl<'a> ImEngine<'a> { + pub fn read_reqs(input: &[AttrPath], expected: &[AttrResp]) { + let im = ImEngine::new_default(); + + im.add_default_acl(); + im.handle_read_reqs(&im.handler(), input, expected); + } + // Helper for handling Read Req sequences for this file pub fn handle_read_reqs( - &mut self, - peer_node_id: u64, + &self, + handler: &ImEngineHandler, input: &[AttrPath], expected: &[AttrResp], ) { - let mut out_buf = [0u8; 400]; - let received = self.gen_read_reqs_output(peer_node_id, input, None, &mut out_buf); + let mut out = heapless::Vec::<_, 1>::new(); + let received = self.gen_read_reqs_output(handler, input, None, &mut out); assert_attr_report(&received, expected) } - pub fn new_with_read_reqs( - matter: &'a Matter<'a>, - input: &[AttrPath], - expected: &[AttrResp], - ) -> Self { - let mut im = Self::new(matter); - - let mut out_buf = [0u8; 400]; - let received = im.gen_read_reqs_output(IM_ENGINE_PEER_ID, input, None, &mut out_buf); - assert_attr_report(&received, expected); - - im - } - - pub fn gen_read_reqs_output<'b>( - &mut self, - peer_node_id: u64, + pub fn gen_read_reqs_output<'c, const N: usize>( + &self, + handler: &ImEngineHandler, input: &[AttrPath], - dataver_filters: Option>, - out_buf: &'b mut [u8], - ) -> ReportDataMsg<'b> { + dataver_filters: Option>, + out: &'c mut heapless::Vec, + ) -> ReportDataMsg<'c> { let mut read_req = ReadReq::new(true).set_attr_requests(input); read_req.dataver_filters = dataver_filters; - let mut input = ImInput::new(OpCode::ReadRequest, &read_req); - input.set_peer_node_id(peer_node_id); + let input = ImInput::new(OpCode::ReadRequest, &read_req); - let (_, out_buf) = self.process(&input, out_buf); + self.process(handler, &[&input], out).unwrap(); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); + for o in &*out { + tlv::print_tlv_list(&o.data); + } + + let root = tlv::get_root_node_struct(&out[0].data).unwrap(); ReportDataMsg::from_tlv(&root).unwrap() } + pub fn write_reqs(input: &[AttrData], expected: &[AttrStatus]) { + let im = ImEngine::new_default(); + + im.add_default_acl(); + im.handle_write_reqs(&im.handler(), input, expected); + } + pub fn handle_write_reqs( - &mut self, - peer_node_id: u64, - peer_cat_ids: Option<&NocCatIds>, + &self, + handler: &ImEngineHandler, input: &[AttrData], expected: &[AttrStatus], ) { - let mut out_buf = [0u8; 400]; let write_req = WriteReq::new(false, input); - let mut input = ImInput::new(OpCode::WriteRequest, &write_req); - input.set_peer_node_id(peer_node_id); - if let Some(cat_ids) = peer_cat_ids { - input.set_cat_ids(cat_ids); - } + let input = ImInput::new(OpCode::WriteRequest, &write_req); + let mut out = heapless::Vec::<_, 1>::new(); + self.process(handler, &[&input], &mut out).unwrap(); - let (_, out_buf) = self.process(&input, &mut out_buf); + for o in &out { + tlv::print_tlv_list(&o.data); + } - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); + let root = tlv::get_root_node_struct(&out[0].data).unwrap(); let mut index = 0; let response_iter = root @@ -124,194 +116,184 @@ impl<'a> ImEngine<'a> { assert_eq!(index, expected.len()); } - pub fn new_with_write_reqs( - matter: &'a Matter<'a>, - input: &[AttrData], - expected: &[AttrStatus], - ) -> Self { - let mut im = Self::new(matter); - - im.handle_write_reqs(IM_ENGINE_PEER_ID, None, input, expected); + pub fn commands(input: &[CmdData], expected: &[ExpectedInvResp]) { + let im = ImEngine::new_default(); - im + im.add_default_acl(); + im.handle_commands(&im.handler(), input, expected) } // Helper for handling Invoke Command sequences pub fn handle_commands( - &mut self, - peer_node_id: u64, + &self, + handler: &ImEngineHandler, input: &[CmdData], expected: &[ExpectedInvResp], ) { - let mut out_buf = [0u8; 400]; let req = InvReq { suppress_response: Some(false), timed_request: Some(false), inv_requests: Some(TLVArray::Slice(input)), }; - let mut input = ImInput::new(OpCode::InvokeRequest, &req); - input.set_peer_node_id(peer_node_id); + let input = ImInput::new(OpCode::InvokeRequest, &req); - let (_, out_buf) = self.process(&input, &mut out_buf); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); - let resp = msg::InvResp::from_tlv(&root).unwrap(); - assert_inv_response(&resp, expected) - } - - pub fn new_with_commands( - matter: &'a Matter<'a>, - input: &[CmdData], - expected: &[ExpectedInvResp], - ) -> Self { - let mut im = ImEngine::new(matter); + let mut out = heapless::Vec::<_, 1>::new(); + self.process(handler, &[&input], &mut out).unwrap(); - im.handle_commands(IM_ENGINE_PEER_ID, input, expected); + for o in &out { + tlv::print_tlv_list(&o.data); + } - im + let root = tlv::get_root_node_struct(&out[0].data).unwrap(); + let resp = msg::InvResp::from_tlv(&root).unwrap(); + assert_inv_response(&resp, expected) } - fn handle_timed_reqs<'b>( - &mut self, + fn gen_timed_reqs_output( + &self, + handler: &ImEngineHandler, opcode: OpCode, request: &dyn ToTLV, timeout: u16, delay: u16, - output: &'b mut [u8], - ) -> (u8, &'b [u8]) { - // Use the same exchange for all parts of the transaction - self.exch = Some(Exchange::new(1, 0, exchange::Role::Responder)); + out: &mut heapless::Vec, + ) { + let mut inp = heapless::Vec::<_, 2>::new(); + + let timed_req = TimedReq { timeout }; + let im_input = ImInput::new_delayed(OpCode::TimedRequest, &timed_req, Some(delay)); if timeout != 0 { // Send Timed Req - let mut tmp_buf = [0u8; 400]; - let timed_req = TimedReq { timeout }; - let im_input = ImInput::new(OpCode::TimedRequest, &timed_req); - let (_, out_buf) = self.process(&im_input, &mut tmp_buf); - tlv::print_tlv_list(out_buf); + inp.push(&im_input).map_err(|_| ErrorCode::NoSpace).unwrap(); } else { warn!("Skipping timed request"); } - // Process any delays - let delay = time::Duration::from_millis(delay.into()); - thread::sleep(delay); - // Send Write Req let input = ImInput::new(opcode, request); - let (resp_opcode, output) = self.process(&input, output); - (resp_opcode, output) + inp.push(&input).map_err(|_| ErrorCode::NoSpace).unwrap(); + + self.process(handler, &inp, out).unwrap(); + + drop(inp); + + for o in out { + tlv::print_tlv_list(&o.data); + } + } + + pub fn timed_write_reqs( + input: &[AttrData], + expected: &WriteResponse, + timeout: u16, + delay: u16, + ) { + let im = ImEngine::new_default(); + + im.add_default_acl(); + im.handle_timed_write_reqs(&im.handler(), input, expected, timeout, delay); } // Helper for handling Write Attribute sequences pub fn handle_timed_write_reqs( - &mut self, + &self, + handler: &ImEngineHandler, input: &[AttrData], expected: &WriteResponse, timeout: u16, delay: u16, ) { - let mut out_buf = [0u8; 400]; + let mut out = heapless::Vec::<_, 2>::new(); let write_req = WriteReq::new(false, input); - let (resp_opcode, out_buf) = self.handle_timed_reqs( + self.gen_timed_reqs_output( + handler, OpCode::WriteRequest, &write_req, timeout, delay, - &mut out_buf, + &mut out, ); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); + + let out = &out[out.len() - 1]; + let root = tlv::get_root_node_struct(&out.data).unwrap(); match expected { WriteResponse::TransactionSuccess(t) => { - assert_eq!( - num::FromPrimitive::from_u8(resp_opcode), - Some(OpCode::WriteResponse) - ); + assert_eq!(out.action, OpCode::WriteResponse); let resp = WriteResp::from_tlv(&root).unwrap(); assert_eq!(resp.write_responses, t); } WriteResponse::TransactionError => { - assert_eq!( - num::FromPrimitive::from_u8(resp_opcode), - Some(OpCode::StatusResponse) - ); + assert_eq!(out.action, OpCode::StatusResponse); let status_resp = StatusResp::from_tlv(&root).unwrap(); assert_eq!(status_resp.status, IMStatusCode::Timeout); } } } - pub fn new_with_timed_write_reqs( - matter: &'a Matter<'a>, - input: &[AttrData], - expected: &WriteResponse, + pub fn timed_commands( + input: &[CmdData], + expected: &TimedInvResponse, timeout: u16, delay: u16, - ) -> Self { - let mut im = ImEngine::new(matter); - - im.handle_timed_write_reqs(input, expected, timeout, delay); + set_timed_request: bool, + ) { + let im = ImEngine::new_default(); - im + im.add_default_acl(); + im.handle_timed_commands( + &im.handler(), + input, + expected, + timeout, + delay, + set_timed_request, + ); } // Helper for handling Invoke Command sequences pub fn handle_timed_commands( - &mut self, + &self, + handler: &ImEngineHandler, input: &[CmdData], expected: &TimedInvResponse, timeout: u16, delay: u16, set_timed_request: bool, ) { - let mut out_buf = [0u8; 400]; + let mut out = heapless::Vec::<_, 2>::new(); let req = InvReq { suppress_response: Some(false), timed_request: Some(set_timed_request), inv_requests: Some(TLVArray::Slice(input)), }; - let (resp_opcode, out_buf) = - self.handle_timed_reqs(OpCode::InvokeRequest, &req, timeout, delay, &mut out_buf); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); + self.gen_timed_reqs_output( + handler, + OpCode::InvokeRequest, + &req, + timeout, + delay, + &mut out, + ); + + let out = &out[out.len() - 1]; + let root = tlv::get_root_node_struct(&out.data).unwrap(); match expected { TimedInvResponse::TransactionSuccess(t) => { - assert_eq!( - num::FromPrimitive::from_u8(resp_opcode), - Some(OpCode::InvokeResponse) - ); + assert_eq!(out.action, OpCode::InvokeResponse); let resp = msg::InvResp::from_tlv(&root).unwrap(); assert_inv_response(&resp, t) } TimedInvResponse::TransactionError(e) => { - assert_eq!( - num::FromPrimitive::from_u8(resp_opcode), - Some(OpCode::StatusResponse) - ); + assert_eq!(out.action, OpCode::StatusResponse); let status_resp = StatusResp::from_tlv(&root).unwrap(); assert_eq!(status_resp.status, *e); } } } - - pub fn new_with_timed_commands( - matter: &'a Matter<'a>, - input: &[CmdData], - expected: &TimedInvResponse, - timeout: u16, - delay: u16, - set_timed_request: bool, - ) -> Self { - let mut im = ImEngine::new(matter); - - im.handle_timed_commands(input, expected, timeout, delay, set_timed_request); - - im - } } diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index 166b7fc8..8efb2c90 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -17,14 +17,19 @@ use crate::common::echo_cluster; use core::borrow::Borrow; +use core::future::pending; +use core::time::Duration; +use embassy_futures::select::select3; use matter::{ acl::{AclEntry, AuthMode}, data_model::{ cluster_basic_information::{self, BasicInfoConfig}, cluster_on_off::{self, OnOffCluster}, - core::DataModel, device_types::{DEV_TYPE_ON_OFF_LIGHT, DEV_TYPE_ROOT_NODE}, - objects::{Endpoint, Node, Privilege}, + objects::{ + AttrData, AttrDataEncoder, AttrDetails, Endpoint, Handler, HandlerCompat, Metadata, + Node, NonBlockingHandler, Privilege, + }, root_endpoint::{self, RootEndpointHandler}, sdm::{ admin_commissioning, @@ -36,21 +41,24 @@ use matter::{ descriptor::{self, DescriptorCluster}, }, }, - error::Error, + error::{Error, ErrorCode}, handler_chain_type, - interaction_model::core::{InteractionModel, OpCode}, - mdns::Mdns, + interaction_model::core::{OpCode, PROTO_ID_INTERACTION_MODEL}, + mdns::DummyMdns, + secure_channel::{self, common::PROTO_ID_SECURE_CHANNEL, spake2p::VerifierData}, tlv::{TLVWriter, TagType, ToTLV}, - transport::packet::Packet, transport::{ - exchange::{self, Exchange, ExchangeCtx}, - network::{Address, IpAddr, Ipv4Addr, SocketAddr}, - packet::MAX_RX_BUF_SIZE, - proto_ctx::ProtoCtx, - session::{CaseDetails, CloneData, NocCatIds, SessionMgr, SessionMode}, + exchange::Notification, + packet::{Packet, MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE}, + pipe::Pipe, + runner::TransportRunner, + }, + transport::{ + network::Address, + session::{CaseDetails, CloneData, NocCatIds, SessionMode}, }, - utils::{rand::dummy_rand, writebuf::WriteBuf}, - Matter, + utils::select::EitherUnwrap, + CommissioningData, Matter, MATTER_PORT, }; use super::echo_cluster::EchoCluster; @@ -74,183 +82,321 @@ impl DevAttDataFetcher for DummyDevAtt { } pub const IM_ENGINE_PEER_ID: u64 = 445566; +pub const IM_ENGINE_REMOTE_PEER_ID: u64 = 123456; + +const NODE: Node<'static> = Node { + id: 0, + endpoints: &[ + Endpoint { + id: 0, + clusters: &[ + descriptor::CLUSTER, + cluster_basic_information::CLUSTER, + general_commissioning::CLUSTER, + nw_commissioning::CLUSTER, + admin_commissioning::CLUSTER, + noc::CLUSTER, + access_control::CLUSTER, + echo_cluster::CLUSTER, + ], + device_type: DEV_TYPE_ROOT_NODE, + }, + Endpoint { + id: 1, + clusters: &[ + descriptor::CLUSTER, + cluster_on_off::CLUSTER, + echo_cluster::CLUSTER, + ], + device_type: DEV_TYPE_ON_OFF_LIGHT, + }, + ], +}; pub struct ImInput<'a> { action: OpCode, data: &'a dyn ToTLV, - peer_id: u64, - cat_ids: NocCatIds, + delay: Option, } impl<'a> ImInput<'a> { pub fn new(action: OpCode, data: &'a dyn ToTLV) -> Self { + Self::new_delayed(action, data, None) + } + + pub fn new_delayed(action: OpCode, data: &'a dyn ToTLV, delay: Option) -> Self { Self { action, data, - peer_id: IM_ENGINE_PEER_ID, - cat_ids: Default::default(), + delay, } } +} - pub fn set_peer_node_id(&mut self, peer: u64) { - self.peer_id = peer; +pub struct ImOutput { + pub action: OpCode, + pub data: heapless::Vec, +} + +pub struct ImEngineHandler<'a> { + handler: handler_chain_type!(OnOffCluster, EchoCluster, DescriptorCluster<'static>, EchoCluster | RootEndpointHandler<'a>), +} + +impl<'a> ImEngineHandler<'a> { + pub fn new(matter: &'a Matter<'a>) -> Self { + let handler = root_endpoint::handler(0, matter) + .chain(0, echo_cluster::ID, EchoCluster::new(2, *matter.borrow())) + .chain(1, descriptor::ID, DescriptorCluster::new(*matter.borrow())) + .chain(1, echo_cluster::ID, EchoCluster::new(3, *matter.borrow())) + .chain(1, cluster_on_off::ID, OnOffCluster::new(*matter.borrow())); + + Self { handler } } - pub fn set_cat_ids(&mut self, cat_ids: &NocCatIds) { - self.cat_ids = *cat_ids; + pub fn echo_cluster(&self, endpoint: u16) -> &EchoCluster { + match endpoint { + 0 => &self.handler.next.next.next.handler, + 1 => &self.handler.next.handler, + _ => panic!(), + } } } -pub type DmHandler<'a> = handler_chain_type!(OnOffCluster, EchoCluster, DescriptorCluster<'a>, EchoCluster | RootEndpointHandler<'a>); +impl<'a> Handler for ImEngineHandler<'a> { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + self.handler.read(attr, encoder) + } + + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + self.handler.write(attr, data) + } + + fn invoke( + &self, + exchange: &matter::transport::exchange::Exchange, + cmd: &matter::data_model::objects::CmdDetails, + data: &matter::tlv::TLVElement, + encoder: matter::data_model::objects::CmdDataEncoder, + ) -> Result<(), Error> { + self.handler.invoke(exchange, cmd, data, encoder) + } +} -pub fn matter(mdns: &mut dyn Mdns) -> Matter<'_> { - #[cfg(feature = "std")] - use matter::utils::epoch::sys_epoch as epoch; +impl<'a> NonBlockingHandler for ImEngineHandler<'a> {} - #[cfg(not(feature = "std"))] - use matter::utils::epoch::dummy_epoch as epoch; +impl<'a> Metadata for ImEngineHandler<'a> { + type MetadataGuard<'g> = Node<'g> where Self: 'g; - Matter::new(&BASIC_INFO, &DummyDevAtt, mdns, epoch, dummy_rand, 5540) + fn lock(&self) -> Self::MetadataGuard<'_> { + NODE + } } +static mut DNS: DummyMdns = DummyMdns; + /// An Interaction Model Engine to facilitate easy testing pub struct ImEngine<'a> { - pub matter: &'a Matter<'a>, - pub im: InteractionModel>>, - // By default, a new exchange is created for every run, if you wish to instead using a specific - // exchange, set this variable. This is helpful in situations where you have to run multiple - // actions in the same transaction (exchange) - pub exch: Option, + pub matter: Matter<'a>, + cat_ids: NocCatIds, } impl<'a> ImEngine<'a> { + pub fn new_default() -> Self { + Self::new(Default::default()) + } + /// Create the interaction model engine - pub fn new(matter: &'a Matter<'a>) -> Self { - let mut default_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - // Only allow the standard peer node id of the IM Engine - default_acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); - matter.acl_mgr.borrow_mut().add(default_acl).unwrap(); - - let dm = DataModel::new( - matter.borrow(), - &Node { - id: 0, - endpoints: &[ - Endpoint { - id: 0, - clusters: &[ - descriptor::CLUSTER, - cluster_basic_information::CLUSTER, - general_commissioning::CLUSTER, - nw_commissioning::CLUSTER, - admin_commissioning::CLUSTER, - noc::CLUSTER, - access_control::CLUSTER, - echo_cluster::CLUSTER, - ], - device_type: DEV_TYPE_ROOT_NODE, - }, - Endpoint { - id: 1, - clusters: &[ - descriptor::CLUSTER, - cluster_on_off::CLUSTER, - echo_cluster::CLUSTER, - ], - device_type: DEV_TYPE_ON_OFF_LIGHT, - }, - ], - }, - root_endpoint::handler(0, matter) - .chain(0, echo_cluster::ID, EchoCluster::new(2, *matter.borrow())) - .chain(1, descriptor::ID, DescriptorCluster::new(*matter.borrow())) - .chain(1, echo_cluster::ID, EchoCluster::new(3, *matter.borrow())) - .chain(1, cluster_on_off::ID, OnOffCluster::new(*matter.borrow())), + pub fn new(cat_ids: NocCatIds) -> Self { + #[cfg(feature = "std")] + use matter::utils::epoch::sys_epoch as epoch; + + #[cfg(not(feature = "std"))] + use matter::utils::epoch::dummy_epoch as epoch; + + #[cfg(feature = "std")] + use matter::utils::rand::sys_rand as rand; + + #[cfg(not(feature = "std"))] + use matter::utils::rand::dummy_rand as rand; + + let matter = Matter::new( + &BASIC_INFO, + &DummyDevAtt, + unsafe { &mut DNS }, + epoch, + rand, + MATTER_PORT, ); - Self { - matter, - im: InteractionModel(dm), - exch: None, - } + Self { matter, cat_ids } } - pub fn echo_cluster(&self, endpoint: u16) -> &EchoCluster { - match endpoint { - 0 => &self.im.0.handler.next.next.next.handler, - 1 => &self.im.0.handler.next.handler, - _ => panic!(), - } + pub fn add_default_acl(&self) { + // Only allow the standard peer node id of the IM Engine + let mut default_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); + default_acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); + self.matter.acl_mgr.borrow_mut().add(default_acl).unwrap(); } - /// Run a transaction through the interaction model engine - pub fn process<'b>(&mut self, input: &ImInput, data_out: &'b mut [u8]) -> (u8, &'b [u8]) { - let mut new_exch = Exchange::new(1, 0, exchange::Role::Responder); - // Choose whether to use a new exchange, or use the one from the ImEngine configuration - let exch = self.exch.as_mut().unwrap_or(&mut new_exch); + pub fn handler(&self) -> ImEngineHandler<'_> { + ImEngineHandler::new(&self.matter) + } - let mut sess_mgr = SessionMgr::new(*self.matter.borrow(), *self.matter.borrow()); + pub fn process( + &self, + handler: &ImEngineHandler, + input: &[&ImInput], + out: &mut heapless::Vec, + ) -> Result<(), Error> { + let mut runner = TransportRunner::new(&self.matter); let clone_data = CloneData::new( - 123456, - input.peer_id, - 10, - 30, - Address::Udp(SocketAddr::new( - IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), - 5542, - )), - SessionMode::Case(CaseDetails::new(1, &input.cat_ids)), + IM_ENGINE_REMOTE_PEER_ID, + IM_ENGINE_PEER_ID, + 1, + 1, + Address::default(), + SessionMode::Case(CaseDetails::new(1, &self.cat_ids)), ); - let sess_idx = sess_mgr.clone_session(&clone_data).unwrap(); - let sess = sess_mgr.get_session_handle(sess_idx); - let exch_ctx = ExchangeCtx { - exch, - sess, - epoch: *self.matter.borrow(), - }; - let mut rx_buf = [0; MAX_RX_BUF_SIZE]; - let mut tx_buf = [0; 1440]; // For the long read tests to run unchanged - let mut rx = Packet::new_rx(&mut rx_buf); - let mut tx = Packet::new_tx(&mut tx_buf); - // Create fake rx packet - rx.set_proto_id(0x01); - rx.set_proto_opcode(input.action as u8); - rx.peer = Address::default(); - - { - let mut buf = [0u8; 400]; - let mut wb = WriteBuf::new(&mut buf); - let mut tw = TLVWriter::new(&mut wb); - - input.data.to_tlv(&mut tw, TagType::Anonymous).unwrap(); - - let input_data = wb.as_slice(); - let in_data_len = input_data.len(); - let rx_buf = rx.as_mut_slice(); - rx_buf[..in_data_len].copy_from_slice(input_data); - rx.get_parsebuf().unwrap().set_len(in_data_len); + + let sess_idx = runner + .transport() + .session_mgr + .borrow_mut() + .clone_session(&clone_data) + .unwrap(); + + let mut tx_pipe_buf = [0; MAX_RX_BUF_SIZE]; + let mut rx_pipe_buf = [0; MAX_TX_BUF_SIZE]; + + let mut tx_buf = [0; MAX_RX_BUF_SIZE]; + let mut rx_buf = [0; MAX_TX_BUF_SIZE]; + + let tx_pipe = Pipe::new(&mut tx_buf); + let rx_pipe = Pipe::new(&mut rx_buf); + + let tx_pipe = &tx_pipe; + let rx_pipe = &rx_pipe; + let tx_pipe_buf = &mut tx_pipe_buf; + let rx_pipe_buf = &mut rx_pipe_buf; + + let handler = &handler; + let runner = &mut runner; + + let mut msg_ctr = runner + .transport() + .session_mgr + .borrow_mut() + .mut_by_index(sess_idx) + .unwrap() + .get_msg_ctr(); + + let resp_notif = Notification::new(); + let resp_notif = &resp_notif; + + embassy_futures::block_on(async move { + select3( + runner.run( + tx_pipe, + rx_pipe, + CommissioningData { + // TODO: Hard-coded for now + verifier: VerifierData::new_with_pw(123456, *self.matter.borrow()), + discriminator: 250, + }, + &HandlerCompat(handler), + ), + async move { + let mut acknowledge = false; + for ip in input { + Self::send(ip, tx_pipe_buf, rx_pipe, msg_ctr, acknowledge).await?; + resp_notif.wait().await; + + if let Some(delay) = ip.delay { + if delay > 0 { + #[cfg(feature = "std")] + std::thread::sleep(Duration::from_millis(delay as _)); + } + } + + msg_ctr += 2; + acknowledge = true; + } + + pending::<()>().await; + + Ok(()) + }, + async move { + out.clear(); + + while out.len() < input.len() { + let (len, _) = tx_pipe.recv(rx_pipe_buf).await; + + let mut rx = Packet::new_rx(&mut rx_pipe_buf[..len]); + + rx.plain_hdr_decode()?; + rx.proto_decode(IM_ENGINE_REMOTE_PEER_ID, Some(&[0u8; 16]))?; + + if rx.get_proto_id() != PROTO_ID_SECURE_CHANNEL + || rx.get_proto_opcode::()? + != secure_channel::common::OpCode::MRPStandAloneAck + { + out.push(ImOutput { + action: rx.get_proto_opcode()?, + data: heapless::Vec::from_slice(rx.as_slice()) + .map_err(|_| ErrorCode::NoSpace)?, + }) + .map_err(|_| ErrorCode::NoSpace)?; + + resp_notif.signal(()); + } + } + + Ok(()) + }, + ) + .await + .unwrap() + })?; + + Ok(()) + } + + async fn send( + input: &ImInput<'_>, + tx_buf: &mut [u8], + rx_pipe: &Pipe<'_>, + msg_ctr: u32, + acknowledge: bool, + ) -> Result<(), Error> { + let mut tx = Packet::new_tx(tx_buf); + + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); + tx.set_proto_opcode(input.action as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + input.data.to_tlv(&mut tw, TagType::Anonymous)?; + + tx.plain.ctr = msg_ctr + 1; + tx.plain.sess_id = 1; + tx.proto.set_initiator(); + + if acknowledge { + tx.proto.set_ack(msg_ctr - 1); } - let mut ctx = ProtoCtx::new(exch_ctx, &rx, &mut tx); - self.im.handle(&mut ctx).unwrap(); - let out_data_len = ctx.tx.as_slice().len(); - data_out[..out_data_len].copy_from_slice(ctx.tx.as_slice()); - let response = ctx.tx.get_proto_raw_opcode(); - (response, &data_out[..out_data_len]) + tx.proto_encode( + Address::default(), + Some(IM_ENGINE_REMOTE_PEER_ID), + IM_ENGINE_PEER_ID, + false, + Some(&[0u8; 16]), + )?; + + rx_pipe.send(Address::default(), tx.as_slice()).await; + + Ok(()) } } - -// TODO - Remove? -// // Create an Interaction Model, Data Model and run a rx/tx transaction through it -// pub fn im_engine<'a>( -// matter: &'a Matter, -// action: OpCode, -// data: &dyn ToTLV, -// data_out: &'a mut [u8], -// ) -> (DmHandler<'a>, u8, &'a mut [u8]) { -// let mut engine = ImEngine::new(matter); -// let input = ImInput::new(action, data); -// let (response, output) = engine.process(&input, data_out); -// (engine.dm.handler, response, output) -// } diff --git a/matter/tests/data_model/acl_and_dataver.rs b/matter/tests/data_model/acl_and_dataver.rs index ebb831ca..853e2cac 100644 --- a/matter/tests/data_model/acl_and_dataver.rs +++ b/matter/tests/data_model/acl_and_dataver.rs @@ -26,7 +26,6 @@ use matter::{ messages::ib::{AttrData, AttrPath, AttrResp, AttrStatus, ClusterPath, DataVersionFilter}, messages::GenericPath, }, - mdns::DummyMdns, tlv::{ElementType, TLVArray, TLVElement, TLVWriter, TagType}, }; @@ -35,7 +34,7 @@ use crate::{ common::{ attributes::*, echo_cluster::{self, ATTR_WRITE_DEFAULT_VALUE}, - im_engine::{matter, ImEngine}, + im_engine::{ImEngine, IM_ENGINE_PEER_ID}, init_env_logger, }, }; @@ -62,30 +61,28 @@ fn wc_read_attribute() { Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); - let peer = 98765; - let mut mdns = DummyMdns {}; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + let im = ImEngine::new_default(); + let handler = im.handler(); // Test1: Empty Response as no ACL matches let input = &[AttrPath::new(&wc_att1)]; let expected = &[]; - im.handle_read_reqs(peer, input, expected); + im.handle_read_reqs(&handler, input, expected); // Add ACL to allow our peer to only access endpoint 0 let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - acl.add_subject(peer).unwrap(); + acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test2: Only Single response as only single endpoint is allowed let input = &[AttrPath::new(&wc_att1)]; let expected = &[attr_data_path!(ep0_att1, ElementType::U16(0x1234))]; - im.handle_read_reqs(peer, input, expected); + im.handle_read_reqs(&handler, input, expected); // Add ACL to allow our peer to also access endpoint 1 let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - acl.add_subject(peer).unwrap(); + acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); acl.add_target(Target::new(Some(1), None, None)).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); @@ -95,7 +92,7 @@ fn wc_read_attribute() { attr_data_path!(ep0_att1, ElementType::U16(0x1234)), attr_data_path!(ep1_att1, ElementType::U16(0x1234)), ]; - im.handle_read_reqs(peer, input, expected); + im.handle_read_reqs(&handler, input, expected); } #[test] @@ -115,25 +112,23 @@ fn exact_read_attribute() { Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); - let peer = 98765; - let mut mdns = DummyMdns {}; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + let im = ImEngine::new_default(); + let handler = im.handler(); // Test1: Unsupported Access error as no ACL matches let input = &[AttrPath::new(&wc_att1)]; let expected = &[attr_status!(&ep0_att1, IMStatusCode::UnsupportedAccess)]; - im.handle_read_reqs(peer, input, expected); + im.handle_read_reqs(&handler, input, expected); // Add ACL to allow our peer to access any endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - acl.add_subject(peer).unwrap(); + acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test2: Only Single response as only single endpoint is allowed let input = &[AttrPath::new(&wc_att1)]; let expected = &[attr_data_path!(ep0_att1, ElementType::U16(0x1234))]; - im.handle_read_reqs(peer, input, expected); + im.handle_read_reqs(&handler, input, expected); } #[test] @@ -177,52 +172,54 @@ fn wc_write_attribute() { EncodeValue::Closure(&attr_data1), )]; - let peer = 98765; - let mut mdns = DummyMdns {}; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + let im = ImEngine::new_default(); + let handler = im.handler(); // Test 1: Wildcard write to an attribute without permission should return // no error - im.handle_write_reqs(peer, None, input0, &[]); - assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); + im.handle_write_reqs(&handler, input0, &[]); + assert_eq!( + ATTR_WRITE_DEFAULT_VALUE, + handler.echo_cluster(0).att_write.get() + ); // Add ACL to allow our peer to access one endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - acl.add_subject(peer).unwrap(); + acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test 2: Wildcard write to attributes will only return attributes // where the writes were successful im.handle_write_reqs( - peer, - None, + &handler, input0, &[AttrStatus::new(&ep0_att, IMStatusCode::Success, 0)], ); - assert_eq!(val0, im.echo_cluster(0).att_write); - assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(1).att_write); + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); + assert_eq!( + ATTR_WRITE_DEFAULT_VALUE, + handler.echo_cluster(1).att_write.get() + ); // Add ACL to allow our peer to access another endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - acl.add_subject(peer).unwrap(); + acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); acl.add_target(Target::new(Some(1), None, None)).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test 3: Wildcard write to attributes will return multiple attributes // where the writes were successful im.handle_write_reqs( - peer, - None, + &handler, input1, &[ AttrStatus::new(&ep0_att, IMStatusCode::Success, 0), AttrStatus::new(&ep1_att, IMStatusCode::Success, 0), ], ); - assert_eq!(val1, im.echo_cluster(0).att_write); - assert_eq!(val1, im.echo_cluster(1).att_write); + assert_eq!(val1, handler.echo_cluster(0).att_write.get()); + assert_eq!(val1, handler.echo_cluster(1).att_write.get()); } #[test] @@ -253,25 +250,26 @@ fn exact_write_attribute() { )]; let expected_success = &[AttrStatus::new(&ep0_att, IMStatusCode::Success, 0)]; - let peer = 98765; - let mut mdns = DummyMdns {}; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + let im = ImEngine::new_default(); + let handler = im.handler(); // Test 1: Exact write to an attribute without permission should return // Unsupported Access Error - im.handle_write_reqs(peer, None, input, expected_fail); - assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); + im.handle_write_reqs(&handler, input, expected_fail); + assert_eq!( + ATTR_WRITE_DEFAULT_VALUE, + handler.echo_cluster(0).att_write.get() + ); // Add ACL to allow our peer to access any endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - acl.add_subject(peer).unwrap(); + acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test 1: Exact write to an attribute with permission should grant // access - im.handle_write_reqs(peer, None, input, expected_success); - assert_eq!(val0, im.echo_cluster(0).att_write); + im.handle_write_reqs(&handler, input, expected_success); + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); } #[test] @@ -303,19 +301,20 @@ fn exact_write_attribute_noc_cat() { )]; let expected_success = &[AttrStatus::new(&ep0_att, IMStatusCode::Success, 0)]; - let peer = 98765; /* CAT in NOC is 1 more, in version, than that in ACL */ let noc_cat = gen_noc_cat(0xABCD, 2); let cat_in_acl = gen_noc_cat(0xABCD, 1); let cat_ids = [noc_cat, 0, 0]; - let mut mdns = DummyMdns; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + let im = ImEngine::new(cat_ids); + let handler = im.handler(); // Test 1: Exact write to an attribute without permission should return // Unsupported Access Error - im.handle_write_reqs(peer, Some(&cat_ids), input, expected_fail); - assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); + im.handle_write_reqs(&handler, input, expected_fail); + assert_eq!( + ATTR_WRITE_DEFAULT_VALUE, + handler.echo_cluster(0).att_write.get() + ); // Add ACL to allow our peer to access any endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); @@ -324,8 +323,8 @@ fn exact_write_attribute_noc_cat() { // Test 1: Exact write to an attribute with permission should grant // access - im.handle_write_reqs(peer, Some(&cat_ids), input, expected_success); - assert_eq!(val0, im.echo_cluster(0).att_write); + im.handle_write_reqs(&handler, input, expected_success); + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); } #[test] @@ -347,21 +346,18 @@ fn insufficient_perms_write() { EncodeValue::Closure(&attr_data0), )]; - let peer = 98765; - let mut mdns = DummyMdns {}; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + let im = ImEngine::new_default(); + let handler = im.handler(); // Add ACL to allow our peer with only OPERATE permission let mut acl = AclEntry::new(1, Privilege::OPERATE, AuthMode::Case); - acl.add_subject(peer).unwrap(); + acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test: Not enough permission should return error im.handle_write_reqs( - peer, - None, + &handler, input0, &[AttrStatus::new( &ep0_att, @@ -369,7 +365,10 @@ fn insufficient_perms_write() { 0, )], ); - assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); + assert_eq!( + ATTR_WRITE_DEFAULT_VALUE, + handler.echo_cluster(0).att_write.get() + ); } #[test] @@ -381,10 +380,9 @@ fn insufficient_perms_write() { /// - Write Attr to Echo Cluster again (successful this time) fn write_with_runtime_acl_add() { init_env_logger(); - let peer = 98765; - let mut mdns = DummyMdns {}; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + + let im = ImEngine::new_default(); + let handler = im.handler(); let val0 = 10; let attr_data0 = |tag, t: &mut TLVWriter| { @@ -403,7 +401,7 @@ fn write_with_runtime_acl_add() { // Create ACL to allow our peer ADMIN on everything let mut allow_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - allow_acl.add_subject(peer).unwrap(); + allow_acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); let acl_att = GenericPath::new( Some(0), @@ -418,7 +416,7 @@ fn write_with_runtime_acl_add() { // Create ACL that only allows write to the ACL Cluster let mut basic_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - basic_acl.add_subject(peer).unwrap(); + basic_acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); basic_acl .add_target(Target::new(Some(0), Some(access_control::ID), None)) .unwrap(); @@ -426,8 +424,7 @@ fn write_with_runtime_acl_add() { // Test: deny write (with error), then ACL is added, then allow write im.handle_write_reqs( - peer, - None, + &handler, // write to echo-cluster attribute, write to acl attribute, write to echo-cluster attribute &[input0.clone(), acl_input, input0], &[ @@ -436,7 +433,7 @@ fn write_with_runtime_acl_add() { AttrStatus::new(&ep0_att, IMStatusCode::Success, 0), ], ); - assert_eq!(val0, im.echo_cluster(0).att_write); + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); } #[test] @@ -448,10 +445,9 @@ fn test_read_data_ver() { // - wildcard endpoint, att1 // - 2 responses are expected init_env_logger(); - let peer = 98765; - let mut mdns = DummyMdns {}; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + + let im = ImEngine::new_default(); + let handler = im.handler(); // Add ACL to allow our peer with only OPERATE permission let acl = AclEntry::new(1, Privilege::OPERATE, AuthMode::Case); @@ -482,10 +478,11 @@ fn test_read_data_ver() { ElementType::U16(0x1234) ), ]; - let mut out_buf = [0u8; 400]; + + let mut out = heapless::Vec::new(); // Test 1: Simple read to retrieve the current Data Version of Cluster at Endpoint 0 - let received = im.gen_read_reqs_output(peer, input, None, &mut out_buf); + let received = im.gen_read_reqs_output::<1>(&handler, input, None, &mut out); assert_attr_report(&received, expected); let data_ver_cluster_at_0 = received @@ -507,11 +504,12 @@ fn test_read_data_ver() { }]; // Test 2: Add Dataversion filter for cluster at endpoint 0 only single entry should be retrieved - let received = im.gen_read_reqs_output( - peer, + let mut out = heapless::Vec::new(); + let received = im.gen_read_reqs_output::<1>( + &handler, input, Some(TLVArray::Slice(&dataver_filter)), - &mut out_buf, + &mut out, ); let expected_only_one = &[attr_data_path!( GenericPath::new( @@ -532,10 +530,10 @@ fn test_read_data_ver() { ); let input = &[AttrPath::new(&ep0_att1)]; let received = im.gen_read_reqs_output( - peer, + &handler, input, Some(TLVArray::Slice(&dataver_filter)), - &mut out_buf, + &mut out, ); let expected_error = &[]; @@ -551,10 +549,9 @@ fn test_write_data_ver() { // - wildcard endpoint, att1 // - 2 responses are expected init_env_logger(); - let peer = 98765; - let mut mdns = DummyMdns {}; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + + let im = ImEngine::new_default(); + let handler = im.handler(); // Add ACL to allow our peer with only OPERATE permission let acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); @@ -576,7 +573,7 @@ fn test_write_data_ver() { let attr_data0 = EncodeValue::Value(&val0); let attr_data1 = EncodeValue::Value(&val1); - let initial_data_ver = im.echo_cluster(0).data_ver.get(); + let initial_data_ver = handler.echo_cluster(0).data_ver.get(); // Test 1: Write with correct dataversion should succeed let input_correct_dataver = &[AttrData::new( @@ -585,12 +582,11 @@ fn test_write_data_ver() { attr_data0, )]; im.handle_write_reqs( - peer, - None, + &handler, input_correct_dataver, &[AttrStatus::new(&ep0_attwrite, IMStatusCode::Success, 0)], ); - assert_eq!(val0, im.echo_cluster(0).att_write); + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); // Test 2: Write with incorrect dataversion should fail // Now the data version would have incremented due to the previous write @@ -600,8 +596,7 @@ fn test_write_data_ver() { attr_data1.clone(), )]; im.handle_write_reqs( - peer, - None, + &handler, input_correct_dataver, &[AttrStatus::new( &ep0_attwrite, @@ -609,12 +604,12 @@ fn test_write_data_ver() { 0, )], ); - assert_eq!(val0, im.echo_cluster(0).att_write); + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); // Test 3: Wildcard write with incorrect dataversion should ignore that cluster // In this case, while the data version is correct for endpoint 0, the endpoint 1's // data version would not match - let new_data_ver = im.echo_cluster(0).data_ver.get(); + let new_data_ver = handler.echo_cluster(0).data_ver.get(); let input_correct_dataver = &[AttrData::new( Some(new_data_ver), @@ -622,12 +617,11 @@ fn test_write_data_ver() { attr_data1, )]; im.handle_write_reqs( - peer, - None, + &handler, input_correct_dataver, &[AttrStatus::new(&ep0_attwrite, IMStatusCode::Success, 0)], ); - assert_eq!(val1, im.echo_cluster(0).att_write); + assert_eq!(val1, handler.echo_cluster(0).att_write.get()); assert_eq!(initial_data_ver + 1, new_data_ver); } diff --git a/matter/tests/data_model/attribute_lists.rs b/matter/tests/data_model/attribute_lists.rs index 636c9c0a..12d4a5d2 100644 --- a/matter/tests/data_model/attribute_lists.rs +++ b/matter/tests/data_model/attribute_lists.rs @@ -22,13 +22,12 @@ use matter::{ messages::ib::{AttrData, AttrPath, AttrStatus}, messages::GenericPath, }, - mdns::DummyMdns, tlv::Nullable, }; use crate::common::{ echo_cluster::{self, TestChecker}, - im_engine::{matter, ImEngine}, + im_engine::ImEngine, init_env_logger, }; @@ -65,8 +64,8 @@ fn attr_list_ops() { EncodeValue::Value(&val0), )]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::write_reqs(input, expected); { let tc = tc_handle.lock().unwrap(); assert_eq!([Some(val0), None, None, None, None], tc.write_list); @@ -79,8 +78,8 @@ fn attr_list_ops() { EncodeValue::Value(&val1), )]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::write_reqs(input, expected); { let tc = tc_handle.lock().unwrap(); assert_eq!([Some(val0), Some(val1), None, None, None], tc.write_list); @@ -94,8 +93,8 @@ fn attr_list_ops() { EncodeValue::Value(&val0), )]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::write_reqs(input, expected); { let tc = tc_handle.lock().unwrap(); assert_eq!([Some(val0), Some(val0), None, None, None], tc.write_list); @@ -105,8 +104,8 @@ fn attr_list_ops() { att_path.list_index = Some(Nullable::NotNull(0)); let input = &[AttrData::new(None, att_path.clone(), delete_item)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::write_reqs(input, expected); { let tc = tc_handle.lock().unwrap(); assert_eq!([None, Some(val0), None, None, None], tc.write_list); @@ -121,8 +120,8 @@ fn attr_list_ops() { EncodeValue::Value(&overwrite_val), )]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::write_reqs(input, expected); { let tc = tc_handle.lock().unwrap(); assert_eq!([Some(20), Some(21), None, None, None], tc.write_list); @@ -132,8 +131,8 @@ fn attr_list_ops() { att_path.list_index = None; let input = &[AttrData::new(None, att_path, delete_all)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::write_reqs(input, expected); { let tc = tc_handle.lock().unwrap(); assert_eq!([None, None, None, None, None], tc.write_list); diff --git a/matter/tests/data_model/attributes.rs b/matter/tests/data_model/attributes.rs index 7d185268..87bd96df 100644 --- a/matter/tests/data_model/attributes.rs +++ b/matter/tests/data_model/attributes.rs @@ -25,18 +25,12 @@ use matter::{ messages::ib::{AttrData, AttrPath, AttrResp, AttrStatus}, messages::GenericPath, }, - mdns::DummyMdns, tlv::{ElementType, TLVElement, TLVWriter, TagType}, }; use crate::{ attr_data, attr_data_path, attr_status, - common::{ - attributes::*, - echo_cluster, - im_engine::{matter, ImEngine}, - init_env_logger, - }, + common::{attributes::*, echo_cluster, im_engine::ImEngine, init_env_logger}, }; #[test] @@ -75,7 +69,7 @@ fn test_read_success() { ElementType::U32(echo_cluster::ATTR_CUSTOM_VALUE) ), ]; - ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::read_reqs(input, expected); } #[test] @@ -122,7 +116,7 @@ fn test_read_unsupported_fields() { attr_status!(&invalid_cluster, IMStatusCode::UnsupportedCluster), attr_status!(&invalid_attribute, IMStatusCode::UnsupportedAttribute), ]; - ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::read_reqs(input, expected); } #[test] @@ -153,7 +147,7 @@ fn test_read_wc_endpoint_all_have_clusters() { ElementType::U16(0x1234) ), ]; - ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::read_reqs(input, expected); } #[test] @@ -178,7 +172,7 @@ fn test_read_wc_endpoint_only_1_has_cluster() { ), ElementType::False )]; - ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::read_reqs(input, expected); } #[test] @@ -285,7 +279,7 @@ fn test_read_wc_endpoint_wc_attribute() { ElementType::U32(echo_cluster::ATTR_CUSTOM_VALUE) ), ]; - ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::read_reqs(input, expected); } #[test] @@ -331,11 +325,14 @@ fn test_write_success() { AttrStatus::new(&ep1_att, IMStatusCode::Success, 0), ]; - let mut mdns = DummyMdns; - let matter = matter(&mut mdns); - let im = ImEngine::new_with_write_reqs(&matter, input, expected); - assert_eq!(val0, im.echo_cluster(0).att_write); - assert_eq!(val1, im.echo_cluster(1).att_write); + let im = ImEngine::new_default(); + let handler = im.handler(); + + im.add_default_acl(); + im.handle_write_reqs(&handler, input, expected); + + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); + assert_eq!(val1, handler.echo_cluster(1).att_write.get()); } #[test] @@ -375,10 +372,13 @@ fn test_write_wc_endpoint() { AttrStatus::new(&ep1_att, IMStatusCode::Success, 0), ]; - let mut mdns = DummyMdns; - let matter = matter(&mut mdns); - let im = ImEngine::new_with_write_reqs(&matter, input, expected); - assert_eq!(val0, im.echo_cluster(0).att_write); + let im = ImEngine::new_default(); + let handler = im.handler(); + + im.add_default_acl(); + im.handle_write_reqs(&handler, input, expected); + + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); } #[test] @@ -467,11 +467,14 @@ fn test_write_unsupported_fields() { AttrStatus::new(&wc_cluster, IMStatusCode::UnsupportedCluster, 0), AttrStatus::new(&wc_attribute, IMStatusCode::UnsupportedAttribute, 0), ]; - let mut mdns = DummyMdns; - let matter = matter(&mut mdns); - let im = ImEngine::new_with_write_reqs(&matter, input, expected); + let im = ImEngine::new_default(); + let handler = im.handler(); + + im.add_default_acl(); + im.handle_write_reqs(&handler, input, expected); + assert_eq!( echo_cluster::ATTR_WRITE_DEFAULT_VALUE, - im.echo_cluster(0).att_write + handler.echo_cluster(0).att_write.get() ); } diff --git a/matter/tests/data_model/commands.rs b/matter/tests/data_model/commands.rs index 0d9c0c3f..ee917713 100644 --- a/matter/tests/data_model/commands.rs +++ b/matter/tests/data_model/commands.rs @@ -17,12 +17,7 @@ use crate::{ cmd_data, - common::{ - commands::*, - echo_cluster, - im_engine::{matter, ImEngine}, - init_env_logger, - }, + common::{commands::*, echo_cluster, im_engine::ImEngine, init_env_logger}, echo_req, echo_resp, }; @@ -32,7 +27,6 @@ use matter::{ core::IMStatusCode, messages::ib::{CmdData, CmdPath, CmdStatus}, }, - mdns::DummyMdns, }; #[test] @@ -44,7 +38,7 @@ fn test_invoke_cmds_success() { let input = &[echo_req!(0, 5), echo_req!(1, 10)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 30)]; - ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); + ImEngine::commands(input, expected); } #[test] @@ -99,7 +93,7 @@ fn test_invoke_cmds_unsupported_fields() { 0, )), ]; - ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); + ImEngine::commands(input, expected); } #[test] @@ -114,7 +108,7 @@ fn test_invoke_cmd_wc_endpoint_all_have_clusters() { ); let input = &[cmd_data!(path, 5)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 15)]; - ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); + ImEngine::commands(input, expected); } #[test] @@ -139,5 +133,5 @@ fn test_invoke_cmd_wc_endpoint_only_1_has_cluster() { IMStatusCode::Success, 0, ))]; - ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); + ImEngine::commands(input, expected); } diff --git a/matter/tests/data_model/long_reads.rs b/matter/tests/data_model/long_reads.rs index 21c25595..e8382e06 100644 --- a/matter/tests/data_model/long_reads.rs +++ b/matter/tests/data_model/long_reads.rs @@ -30,13 +30,7 @@ use matter::{ }, messages::{msg::SubscribeReq, GenericPath}, }, - mdns::DummyMdns, - tlv::{self, ElementType, FromTLV, TLVElement, TagType, ToTLV}, - transport::{ - exchange::{self, Exchange}, - packet::MAX_RX_BUF_SIZE, - }, - Matter, + tlv::{self, ElementType, FromTLV, TLVElement, TagType}, }; use crate::{ @@ -44,35 +38,11 @@ use crate::{ common::{ attributes::*, echo_cluster as echo, - im_engine::{matter, ImEngine, ImInput}, + im_engine::{ImEngine, ImInput}, init_env_logger, }, }; -pub struct LongRead<'a> { - im_engine: ImEngine<'a>, -} - -impl<'a> LongRead<'a> { - pub fn new(matter: &'a Matter<'a>) -> Self { - let mut im_engine = ImEngine::new(matter); - // Use the same exchange for all parts of the transaction - im_engine.exch = Some(Exchange::new(1, 0, exchange::Role::Responder)); - Self { im_engine } - } - - pub fn process<'p>( - &mut self, - action: OpCode, - data: &dyn ToTLV, - data_out: &'p mut [u8], - ) -> (u8, &'p [u8]) { - let input = ImInput::new(action, data); - let (response, output) = self.im_engine.process(&input, data_out); - (response, output) - } -} - fn wildcard_read_resp(part: u8) -> Vec> { // For brevity, we only check the AttrPath, not the actual 'data' let dont_care = ElementType::U8(0); @@ -215,6 +185,9 @@ fn wildcard_read_resp(part: u8) -> Vec> { acl::AttributesDiscriminants::Extension, dont_care.clone() ), + ]; + + let part2 = vec![ attr_data!( 0, 31, @@ -266,9 +239,6 @@ fn wildcard_read_resp(part: u8) -> Vec> { descriptor::Attributes::DeviceTypeList, dont_care.clone() ), - ]; - - let part2 = vec![ attr_data!(1, 29, descriptor::Attributes::ServerList, dont_care.clone()), attr_data!(1, 29, descriptor::Attributes::PartsList, dont_care.clone()), attr_data!(1, 29, descriptor::Attributes::ClientList, dont_care.clone()), @@ -318,74 +288,103 @@ fn wildcard_read_resp(part: u8) -> Vec> { fn test_long_read_success() { // Read the entire attribute database, which requires 2 reads to complete init_env_logger(); - let mut mdns = DummyMdns; - let matter = matter(&mut mdns); - let mut lr = LongRead::new(&matter); - let mut output = [0_u8; MAX_RX_BUF_SIZE + 100]; + + let mut out = heapless::Vec::<_, 3>::new(); + let im = ImEngine::new_default(); + let handler = im.handler(); + + im.add_default_acl(); let wc_path = GenericPath::new(None, None, None); let read_all = [AttrPath::new(&wc_path)]; let read_req = ReadReq::new(true).set_attr_requests(&read_all); let expected_part1 = wildcard_read_resp(1); - let (out_code, out_data) = lr.process(OpCode::ReadRequest, &read_req, &mut output); - let root = tlv::get_root_node_struct(out_data).unwrap(); - let report_data = ReportDataMsg::from_tlv(&root).unwrap(); - assert_attr_report_skip_data(&report_data, &expected_part1); - assert_eq!(report_data.more_chunks, Some(true)); - assert_eq!(out_code, OpCode::ReportData as u8); - // Ask for the next read by sending a status report let status_report = StatusResp { status: IMStatusCode::Success, }; let expected_part2 = wildcard_read_resp(2); - let (out_code, out_data) = lr.process(OpCode::StatusResponse, &status_report, &mut output); - let root = tlv::get_root_node_struct(out_data).unwrap(); + + im.process( + &handler, + &[ + &ImInput::new(OpCode::ReadRequest, &read_req), + &ImInput::new(OpCode::StatusResponse, &status_report), + ], + &mut out, + ) + .unwrap(); + + assert_eq!(out.len(), 2); + + assert_eq!(out[0].action, OpCode::ReportData); + + let root = tlv::get_root_node_struct(&out[0].data).unwrap(); + let report_data = ReportDataMsg::from_tlv(&root).unwrap(); + assert_attr_report_skip_data(&report_data, &expected_part1); + assert_eq!(report_data.more_chunks, Some(true)); + + assert_eq!(out[1].action, OpCode::ReportData); + + let root = tlv::get_root_node_struct(&out[1].data).unwrap(); let report_data = ReportDataMsg::from_tlv(&root).unwrap(); assert_attr_report_skip_data(&report_data, &expected_part2); assert_eq!(report_data.more_chunks, None); - assert_eq!(out_code, OpCode::ReportData as u8); } #[test] fn test_long_read_subscription_success() { // Subscribe to the entire attribute database, which requires 2 reads to complete init_env_logger(); - let mut mdns = DummyMdns; - let matter = matter(&mut mdns); - let mut lr = LongRead::new(&matter); - let mut output = [0_u8; MAX_RX_BUF_SIZE + 100]; + + let mut out = heapless::Vec::<_, 3>::new(); + let im = ImEngine::new_default(); + let handler = im.handler(); + + im.add_default_acl(); let wc_path = GenericPath::new(None, None, None); let read_all = [AttrPath::new(&wc_path)]; let subs_req = SubscribeReq::new(true, 1, 20).set_attr_requests(&read_all); let expected_part1 = wildcard_read_resp(1); - let (out_code, out_data) = lr.process(OpCode::SubscribeRequest, &subs_req, &mut output); - let root = tlv::get_root_node_struct(out_data).unwrap(); - let report_data = ReportDataMsg::from_tlv(&root).unwrap(); - assert_attr_report_skip_data(&report_data, &expected_part1); - assert_eq!(report_data.more_chunks, Some(true)); - assert_eq!(out_code, OpCode::ReportData as u8); - // Ask for the next read by sending a status report let status_report = StatusResp { status: IMStatusCode::Success, }; let expected_part2 = wildcard_read_resp(2); - let (out_code, out_data) = lr.process(OpCode::StatusResponse, &status_report, &mut output); - let root = tlv::get_root_node_struct(out_data).unwrap(); + + im.process( + &handler, + &[ + &ImInput::new(OpCode::SubscribeRequest, &subs_req), + &ImInput::new(OpCode::StatusResponse, &status_report), + &ImInput::new(OpCode::StatusResponse, &status_report), + ], + &mut out, + ) + .unwrap(); + + assert_eq!(out.len(), 3); + + assert_eq!(out[0].action, OpCode::ReportData); + + let root = tlv::get_root_node_struct(&out[0].data).unwrap(); + let report_data = ReportDataMsg::from_tlv(&root).unwrap(); + assert_attr_report_skip_data(&report_data, &expected_part1); + assert_eq!(report_data.more_chunks, Some(true)); + + assert_eq!(out[1].action, OpCode::ReportData); + + let root = tlv::get_root_node_struct(&out[1].data).unwrap(); let report_data = ReportDataMsg::from_tlv(&root).unwrap(); assert_attr_report_skip_data(&report_data, &expected_part2); assert_eq!(report_data.more_chunks, None); - assert_eq!(out_code, OpCode::ReportData as u8); - // Finally confirm subscription - let (out_code, out_data) = lr.process(OpCode::StatusResponse, &status_report, &mut output); - tlv::print_tlv_list(out_data); - let root = tlv::get_root_node_struct(out_data).unwrap(); + assert_eq!(out[2].action, OpCode::SubscribeResponse); + + let root = tlv::get_root_node_struct(&out[2].data).unwrap(); let subs_resp = SubscribeResp::from_tlv(&root).unwrap(); - assert_eq!(out_code, OpCode::SubscribeResponse as u8); assert_eq!(subs_resp.subs_id, 1); } diff --git a/matter/tests/data_model/timed_requests.rs b/matter/tests/data_model/timed_requests.rs index 3f441901..e4eb960e 100644 --- a/matter/tests/data_model/timed_requests.rs +++ b/matter/tests/data_model/timed_requests.rs @@ -22,7 +22,6 @@ use matter::{ messages::ib::{AttrData, AttrPath, AttrStatus}, messages::{ib::CmdData, ib::CmdPath, GenericPath}, }, - mdns::DummyMdns, tlv::TLVWriter, }; @@ -31,7 +30,7 @@ use crate::{ commands::*, echo_cluster, handlers::{TimedInvResponse, WriteResponse}, - im_engine::{matter, ImEngine}, + im_engine::ImEngine, init_env_logger, }, echo_req, echo_resp, @@ -75,25 +74,20 @@ fn test_timed_write_fail_and_success() { ]; // Test with incorrect handling - ImEngine::new_with_timed_write_reqs( - &matter(&mut DummyMdns), - input, - &WriteResponse::TransactionError, - 400, - 500, - ); + ImEngine::timed_write_reqs(input, &WriteResponse::TransactionError, 100, 500); // Test with correct handling - let mut mdns = DummyMdns; - let matter = matter(&mut mdns); - let im = ImEngine::new_with_timed_write_reqs( - &matter, + let im = ImEngine::new_default(); + let handler = im.handler(); + im.add_default_acl(); + im.handle_timed_write_reqs( + &handler, input, &WriteResponse::TransactionSuccess(expected), 400, 0, ); - assert_eq!(val0, im.echo_cluster(0).att_write); + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); } #[test] @@ -103,8 +97,7 @@ fn test_timed_cmd_success() { let input = &[echo_req!(0, 5), echo_req!(1, 10)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 30)]; - ImEngine::new_with_timed_commands( - &matter(&mut DummyMdns), + ImEngine::timed_commands( input, &TimedInvResponse::TransactionSuccess(expected), 400, @@ -119,11 +112,10 @@ fn test_timed_cmd_timeout() { init_env_logger(); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; - ImEngine::new_with_timed_commands( - &matter(&mut DummyMdns), + ImEngine::timed_commands( input, &TimedInvResponse::TransactionError(IMStatusCode::Timeout), - 400, + 100, 500, true, ); @@ -135,8 +127,7 @@ fn test_timed_cmd_timedout_mismatch() { init_env_logger(); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; - ImEngine::new_with_timed_commands( - &matter(&mut DummyMdns), + ImEngine::timed_commands( input, &TimedInvResponse::TransactionError(IMStatusCode::TimedRequestMisMatch), 400, @@ -145,8 +136,7 @@ fn test_timed_cmd_timedout_mismatch() { ); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; - ImEngine::new_with_timed_commands( - &matter(&mut DummyMdns), + ImEngine::timed_commands( input, &TimedInvResponse::TransactionError(IMStatusCode::TimedRequestMisMatch), 0, diff --git a/matter/tests/interaction_model.rs b/matter/tests/interaction_model.rs deleted file mode 100644 index 9642ab23..00000000 --- a/matter/tests/interaction_model.rs +++ /dev/null @@ -1,152 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use matter::data_model::core::DataHandler; -use matter::error::Error; -use matter::interaction_model::core::Interaction; -use matter::interaction_model::core::InteractionModel; -use matter::interaction_model::core::OpCode; -use matter::interaction_model::core::Transaction; -use matter::transport::exchange::Exchange; -use matter::transport::exchange::ExchangeCtx; -use matter::transport::network::Address; -use matter::transport::network::IpAddr; -use matter::transport::network::Ipv4Addr; -use matter::transport::network::SocketAddr; -use matter::transport::packet::Packet; -use matter::transport::packet::MAX_RX_BUF_SIZE; -use matter::transport::packet::MAX_TX_BUF_SIZE; -use matter::transport::proto_ctx::ProtoCtx; -use matter::transport::session::SessionMgr; -use matter::utils::epoch::dummy_epoch; -use matter::utils::rand::dummy_rand; - -struct Node { - pub endpoint: u16, - pub cluster: u32, - pub command: u16, - pub variable: u8, -} - -struct DataModel { - node: Node, -} - -impl DataModel { - pub fn new(node: Node) -> Self { - DataModel { node } - } -} - -impl DataHandler for DataModel { - fn handle( - &mut self, - interaction: Interaction, - _tx: &mut Packet, - _transaction: &mut Transaction, - ) -> Result { - if let Interaction::Invoke(req) = interaction { - if let Some(inv_requests) = &req.inv_requests { - for i in inv_requests.iter() { - let data = if let Some(data) = i.data.unwrap_tlv() { - data - } else { - continue; - }; - let cmd_path_ib = i.path; - let common_data = &mut self.node; - common_data.endpoint = cmd_path_ib.path.endpoint.unwrap_or(1); - common_data.cluster = cmd_path_ib.path.cluster.unwrap_or(0); - common_data.command = cmd_path_ib.path.leaf.unwrap_or(0) as u16; - data.confirm_struct().unwrap(); - common_data.variable = data.find_tag(0).unwrap().u8().unwrap(); - } - } - } - - Ok(false) - } -} - -fn handle_data(action: OpCode, data_in: &[u8], data_out: &mut [u8]) -> (DataModel, usize) { - let data_model = DataModel::new(Node { - endpoint: 0, - cluster: 0, - command: 0, - variable: 0, - }); - let mut interaction_model = InteractionModel(data_model); - let mut exch: Exchange = Default::default(); - let mut sess_mgr = SessionMgr::new(dummy_epoch, dummy_rand); - let sess_idx = sess_mgr - .get_or_add( - 0, - Address::Udp(SocketAddr::new( - IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), - 5542, - )), - None, - false, - ) - .unwrap(); - let sess = sess_mgr.get_session_handle(sess_idx); - let exch_ctx = ExchangeCtx { - exch: &mut exch, - sess, - epoch: dummy_epoch, - }; - let mut rx_buf = [0; MAX_RX_BUF_SIZE]; - let mut tx_buf = [0; MAX_TX_BUF_SIZE]; - let mut rx = Packet::new_rx(&mut rx_buf); - let mut tx = Packet::new_tx(&mut tx_buf); - // Create fake rx packet - rx.set_proto_id(0x01); - rx.set_proto_opcode(action as u8); - rx.peer = Address::default(); - let in_data_len = data_in.len(); - let rx_buf = rx.as_mut_slice(); - rx_buf[..in_data_len].copy_from_slice(data_in); - - let mut ctx = ProtoCtx::new(exch_ctx, &rx, &mut tx); - - interaction_model.handle(&mut ctx).unwrap(); - - let out_len = ctx.tx.as_mut_slice().len(); - data_out[..out_len].copy_from_slice(ctx.tx.as_mut_slice()); - (interaction_model.0, out_len) -} - -#[test] -fn test_valid_invoke_cmd() -> Result<(), Error> { - // An invoke command for endpoint 0, cluster 49, command 12 and a u8 variable value of 0x05 - - let b = [ - 0x15, 0x28, 0x00, 0x28, 0x01, 0x36, 0x02, 0x15, 0x37, 0x00, 0x25, 0x00, 0x00, 0x00, 0x26, - 0x01, 0x31, 0x00, 0x00, 0x00, 0x26, 0x02, 0x0c, 0x00, 0x00, 0x00, 0x18, 0x35, 0x01, 0x24, - 0x00, 0x05, 0x18, 0x18, 0x18, 0x18, - ]; - - let mut out_buf: [u8; 20] = [0; 20]; - - let (data_model, _) = handle_data(OpCode::InvokeRequest, &b, &mut out_buf); - let data = &data_model.node; - assert_eq!(data.endpoint, 0); - assert_eq!(data.cluster, 49); - assert_eq!(data.command, 12); - assert_eq!(data.variable, 5); - Ok(()) -} diff --git a/sdkconfig.defaults b/sdkconfig.defaults new file mode 100644 index 00000000..6ccea500 --- /dev/null +++ b/sdkconfig.defaults @@ -0,0 +1,6 @@ +# Workaround for https://github.com/espressif/esp-idf/issues/7631 +CONFIG_MBEDTLS_CERTIFICATE_BUNDLE=n +CONFIG_MBEDTLS_CERTIFICATE_BUNDLE_DEFAULT_FULL=n + +# Examples often require a larger than the default stack size for the main thread. +CONFIG_ESP_MAIN_TASK_STACK_SIZE=10000 From 9576fd8d9a10f62005c99a428960169b36fbfeaa Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Fri, 30 Jun 2023 12:45:21 +0000 Subject: [PATCH 63/72] Fix #60 --- matter/src/mdns.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index d07ba107..c2ae5385 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -19,7 +19,7 @@ use core::fmt::Write; use crate::{data_model::cluster_basic_information::BasicInfoConfig, error::Error}; -#[cfg(all(feature = "std", feature = "astro-dnssd"))] +#[cfg(all(feature = "std", target_os = "macos"))] pub mod astro; pub mod builtin; pub mod proto; @@ -55,16 +55,16 @@ where } } -#[cfg(all(feature = "std", feature = "astro-dnssd"))] +#[cfg(all(feature = "std", target_os = "macos"))] pub type DefaultMdns<'a> = astro::Mdns<'a>; -#[cfg(all(feature = "std", feature = "astro-dnssd"))] +#[cfg(all(feature = "std", target_os = "macos"))] pub type DefaultMdnsRunner<'a> = astro::MdnsRunner<'a>; -#[cfg(not(all(feature = "std", feature = "astro-dnssd")))] +#[cfg(not(all(feature = "std", target_os = "macos")))] pub type DefaultMdns<'a> = builtin::Mdns<'a>; -#[cfg(not(all(feature = "std", feature = "astro-dnssd")))] +#[cfg(not(all(feature = "std", target_os = "macos")))] pub type DefaultMdnsRunner<'a> = builtin::MdnsRunner<'a>; pub struct DummyMdns; From 762438ca8ee739285dee3eb575b374040e3dce0b Mon Sep 17 00:00:00 2001 From: Kedar Sovani Date: Thu, 20 Jul 2023 10:13:46 +0530 Subject: [PATCH 64/72] on_off_light: Save ACLs and Fabrics to PSM --- examples/onoff_light/src/main.rs | 60 ++++++++++++++++++++++---------- matter/src/core.rs | 12 +++++++ matter/src/transport/core.rs | 22 +++--------- matter/src/transport/runner.rs | 4 +-- 4 files changed, 61 insertions(+), 37 deletions(-) diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index ecfc71eb..627eda6d 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -18,7 +18,7 @@ use core::borrow::Borrow; use core::pin::pin; -use embassy_futures::select::select; +use embassy_futures::select::select3; use log::info; use matter::core::{CommissioningData, Matter}; use matter::data_model::cluster_basic_information::BasicInfoConfig; @@ -29,6 +29,7 @@ use matter::data_model::root_endpoint; use matter::data_model::system_model::descriptor; use matter::error::Error; use matter::mdns::{DefaultMdns, DefaultMdnsRunner}; +use matter::persist::FilePsm; use matter::secure_channel::spake2p::VerifierData; use matter::transport::network::{Ipv4Addr, Ipv6Addr}; use matter::transport::runner::{RxBuf, TransportRunner, TxBuf}; @@ -73,6 +74,12 @@ fn run() -> Result<(), Error> { device_name: "OnOff Light", }; + let psm_path = std::env::temp_dir().join("matter-iot"); + info!("Persisting from/to {}", psm_path.display()); + + #[cfg(all(feature = "std", not(target_os = "espidf")))] + let psm = matter::persist::FilePsm::new(psm_path)?; + let (ipv4_addr, ipv6_addr, interface) = initialize_network()?; let mdns = DefaultMdns::new( @@ -124,16 +131,18 @@ fn run() -> Result<(), Error> { let mut tx_buf = TxBuf::uninit(); let mut rx_buf = RxBuf::uninit(); - // #[cfg(all(feature = "std", not(target_os = "espidf")))] - // { - // if let Some(data) = psm.load("acls", buf)? { - // matter.load_acls(data)?; - // } - - // if let Some(data) = psm.load("fabrics", buf)? { - // matter.load_fabrics(data)?; - // } - // } + #[cfg(all(feature = "std", not(target_os = "espidf")))] + { + let mut buf = [0; 4096]; + let buf = &mut buf; + if let Some(data) = psm.load("acls", buf)? { + matter.load_acls(data)?; + } + + if let Some(data) = psm.load("fabrics", buf)? { + matter.load_fabrics(data)?; + } + } let node = Node { id: 0, @@ -179,13 +188,8 @@ fn run() -> Result<(), Error> { // connect the pipes of the `run` method with your own UDP stack let mut mdns = pin!(mdns_runner.run_udp()); - select( - &mut transport, - &mut mdns, - //save(transport, &psm), - ) - .await - .unwrap() + let mut save = pin!(save(matter, &psm)); + select3(&mut transport, &mut mdns, &mut save).await.unwrap() }); // NOTE: For no_std, replace with your own no_std way of polling the future @@ -299,6 +303,26 @@ fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr, u32), Error> { Ok((ip, ipv6, 0 as _)) } +#[cfg(all(feature = "std", not(target_os = "espidf")))] +#[inline(never)] +async fn save(matter: &Matter<'_>, psm: &FilePsm) -> Result<(), Error> { + let mut buf = [0; 4096]; + let buf = &mut buf; + + loop { + matter.wait_changed().await; + if matter.is_changed() { + if let Some(data) = matter.store_acls(buf)? { + psm.store("acls", data)?; + } + + if let Some(data) = matter.store_fabrics(buf)? { + psm.store("fabrics", data)?; + } + } + } +} + #[cfg(target_os = "espidf")] #[inline(never)] fn initialize_logger() { diff --git a/matter/src/core.rs b/matter/src/core.rs index 35c86771..13c0930f 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -28,6 +28,7 @@ use crate::{ mdns::Mdns, pairing::{print_pairing_code_and_qr, DiscoveryCapabilities}, secure_channel::{pake::PaseMgr, spake2p::VerifierData}, + transport::exchange::Notification, utils::{epoch::Epoch, rand::Rand}, }; @@ -48,6 +49,7 @@ pub struct Matter<'a> { pub acl_mgr: RefCell, pub pase_mgr: RefCell, pub failsafe: RefCell, + pub persist_notification: Notification, pub mdns: &'a dyn Mdns, pub epoch: Epoch, pub rand: Rand, @@ -91,6 +93,7 @@ impl<'a> Matter<'a> { acl_mgr: RefCell::new(AclMgr::new()), pase_mgr: RefCell::new(PaseMgr::new(epoch, rand)), failsafe: RefCell::new(FailSafe::new()), + persist_notification: Notification::new(), mdns, epoch, rand, @@ -157,6 +160,15 @@ impl<'a> Matter<'a> { Ok(false) } } + pub fn notify_changed(&self) { + if self.is_changed() { + self.persist_notification.signal(()); + } + } + + pub async fn wait_changed(&self) { + self.persist_notification.wait().await + } } impl<'a> Borrow> for Matter<'a> { diff --git a/matter/src/transport/core.rs b/matter/src/transport/core.rs index 2a54b4a3..98c2fbab 100644 --- a/matter/src/transport/core.rs +++ b/matter/src/transport/core.rs @@ -58,7 +58,6 @@ pub struct Transport<'a> { matter: &'a Matter<'a>, pub(crate) exchanges: RefCell>, pub(crate) send_notification: Notification, - pub(crate) persist_notification: Notification, pub session_mgr: RefCell, } @@ -72,7 +71,6 @@ impl<'a> Transport<'a> { matter, exchanges: RefCell::new(heapless::Vec::new()), send_notification: Notification::new(), - persist_notification: Notification::new(), session_mgr: RefCell::new(SessionMgr::new(epoch, rand)), } } @@ -128,7 +126,7 @@ impl<'a> Transport<'a> { } } - self.notify_changed(); + self.matter().notify_changed(); } } @@ -142,7 +140,7 @@ impl<'a> Transport<'a> { construction_notification, }; - self.notify_changed(); + self.matter().notify_changed(); Ok(Some(constructor)) } else if src_rx.proto.proto_id == PROTO_ID_SECURE_CHANNEL @@ -169,7 +167,7 @@ impl<'a> Transport<'a> { } } - self.notify_changed(); + self.matter().notify_changed(); Ok(None) } @@ -232,7 +230,7 @@ impl<'a> Transport<'a> { }); if let Some(ctx) = ctx { - self.notify_changed(); + self.matter().notify_changed(); let state = &mut ctx.state; @@ -291,7 +289,7 @@ impl<'a> Transport<'a> { dest_tx.log("Sending packet"); self.pre_send(ctx, dest_tx)?; - self.notify_changed(); + self.matter().notify_changed(); return Ok(true); } @@ -414,14 +412,4 @@ impl<'a> Transport<'a> { ) -> Option<&'r mut ExchangeCtx> { exchanges.iter_mut().find(|exchange| exchange.id == *id) } - - pub fn notify_changed(&self) { - if self.matter().is_changed() { - self.persist_notification.signal(()); - } - } - - pub async fn wait_changed(&self) { - self.persist_notification.wait().await - } } diff --git a/matter/src/transport/runner.rs b/matter/src/transport/runner.rs index f94e819e..c4cd4ed7 100644 --- a/matter/src/transport/runner.rs +++ b/matter/src/transport/runner.rs @@ -370,7 +370,7 @@ impl<'a> TransportRunner<'a> { sc.handle(&mut exchange, &mut rx, &mut tx).await?; - transport.notify_changed(); + transport.matter().notify_changed(); } PROTO_ID_INTERACTION_MODEL => { let dm = DataModel::new(handler); @@ -380,7 +380,7 @@ impl<'a> TransportRunner<'a> { dm.handle(&mut exchange, &mut rx, &mut tx, &mut rx_status) .await?; - transport.notify_changed(); + transport.matter().notify_changed(); } other => { error!("Unknown Proto-ID: {}", other); From 0eecce5f8d55c2174649f81477d4f3a848f0ced9 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Fri, 14 Jul 2023 22:26:01 +0000 Subject: [PATCH 65/72] UDP stack based on embassy-net --- examples/onoff_light/src/main.rs | 25 +++- matter/Cargo.toml | 4 + matter/src/mdns/builtin.rs | 46 ++++--- matter/src/transport/network.rs | 32 +++++ matter/src/transport/runner.rs | 30 +++-- matter/src/transport/udp.rs | 217 ++++++++++++++++++++++++------- 6 files changed, 276 insertions(+), 78 deletions(-) diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 627eda6d..93eab5f7 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -31,8 +31,9 @@ use matter::error::Error; use matter::mdns::{DefaultMdns, DefaultMdnsRunner}; use matter::persist::FilePsm; use matter::secure_channel::spake2p::VerifierData; -use matter::transport::network::{Ipv4Addr, Ipv6Addr}; +use matter::transport::network::{Ipv4Addr, Ipv6Addr, NetworkStack}; use matter::transport::runner::{RxBuf, TransportRunner, TxBuf}; +use matter::transport::udp::UdpBuffers; use matter::utils::select::EitherUnwrap; mod dev_att; @@ -131,6 +132,13 @@ fn run() -> Result<(), Error> { let mut tx_buf = TxBuf::uninit(); let mut rx_buf = RxBuf::uninit(); + // NOTE (no_std): If using the `embassy-net` UDP implementation, replace this dummy stack with the `embassy-net` one + // When using a custom UDP stack, remove this + let stack = NetworkStack::new(); + + let mut mdns_udp_buffers = UdpBuffers::new(); + let mut trans_udp_buffers = UdpBuffers::new(); + #[cfg(all(feature = "std", not(target_os = "espidf")))] { let mut buf = [0; 4096]; @@ -164,6 +172,9 @@ fn run() -> Result<(), Error> { let runner = &mut runner; let tx_buf = &mut tx_buf; let rx_buf = &mut rx_buf; + let stack = &stack; + let mdns_udp_buffers = &mut mdns_udp_buffers; + let trans_udp_buffers = &mut trans_udp_buffers; info!( "About to run wth node {:p}, handler {:p}, transport runner {:p}, mdns_runner {:p}", @@ -171,9 +182,11 @@ fn run() -> Result<(), Error> { ); let mut fut = pin!(async move { - // NOTE (no_std): On no_std, the `run_udp` is a no-op so you might want to replace it with `run` and - // connect the pipes of the `run` method with your own UDP stack + // NOTE: If using a custom UDP stack, replace `run_udp` with `run` + // and connect the pipes of the `run` method with the custom UDP stack let mut transport = pin!(runner.run_udp( + stack, + trans_udp_buffers, tx_buf, rx_buf, CommissioningData { @@ -184,9 +197,9 @@ fn run() -> Result<(), Error> { &handler, )); - // NOTE (no_std): On no_std, the `run_udp` is a no-op so you might want to replace it with `run` and - // connect the pipes of the `run` method with your own UDP stack - let mut mdns = pin!(mdns_runner.run_udp()); + // NOTE: If using a custom UDP stack, replace `run_udp` with `run` + // and connect the pipes of the `run` method with the custom UDP stack + let mut mdns = pin!(mdns_runner.run_udp(stack, mdns_udp_buffers)); let mut save = pin!(save(matter, &psm)); select3(&mut transport, &mut mdns, &mut save).await.unwrap() diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 5b8816d8..87533b5d 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -24,6 +24,7 @@ nightly = [] crypto_openssl = ["alloc", "openssl", "foreign-types", "hmac", "sha2"] crypto_mbedtls = ["alloc", "mbedtls"] crypto_rustcrypto = ["alloc", "sha2", "hmac", "pbkdf2", "hkdf", "aes", "ccm", "p256", "elliptic-curve", "crypto-bigint", "x509-cert", "rand_core"] +embassy-net = ["dep:embassy-net", "dep:embassy-net-driver", "smoltcp"] [dependencies] matter_macro_derive = { path = "../matter_macro_derive" } @@ -46,6 +47,9 @@ embassy-time = { version = "0.1.1", features = ["generic-queue-8"] } embassy-sync = "0.2" critical-section = "1.1.1" domain = { version = "0.7.2", default_features = false, features = ["heapless"] } +embassy-net = { version = "0.1", features = ["udp", "igmp", "proto-ipv6", "medium-ethernet", "medium-ip"], optional = true } +embassy-net-driver = { version = "0.1", optional = true } +smoltcp = { version = "0.10", default-features = false, optional = true } # STD-only dependencies rand = { version = "0.8.5", optional = true } diff --git a/matter/src/mdns/builtin.rs b/matter/src/mdns/builtin.rs index 7b6f8912..c869218d 100644 --- a/matter/src/mdns/builtin.rs +++ b/matter/src/mdns/builtin.rs @@ -1,17 +1,15 @@ -use core::{cell::RefCell, mem::MaybeUninit, pin::pin}; +use core::{cell::RefCell, pin::pin}; use domain::base::name::FromStrError; use domain::base::{octets::ParseError, ShortBuf}; -use embassy_futures::select::{select, select3}; +use embassy_futures::select::select; use embassy_time::{Duration, Timer}; use log::info; use crate::data_model::cluster_basic_information::BasicInfoConfig; use crate::error::{Error, ErrorCode}; use crate::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; -use crate::transport::packet::{MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE}; use crate::transport::pipe::{Chunk, Pipe}; -use crate::transport::udp::UdpListener; use crate::utils::select::{EitherUnwrap, Notification}; use super::{ @@ -19,18 +17,14 @@ use super::{ Service, ServiceMode, }; -const IP_BIND_ADDR: IpAddr = IpAddr::V6(Ipv6Addr::UNSPECIFIED); - const IP_BROADCAST_ADDR: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251); const IPV6_BROADCAST_ADDR: Ipv6Addr = Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x00fb); const PORT: u16 = 5353; -type MdnsTxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; -type MdnsRxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>; - pub struct Mdns<'a> { host: Host<'a>, + #[allow(unused)] interface: u32, dev_det: &'a BasicInfoConfig<'a>, matter_port: u16, @@ -108,9 +102,21 @@ impl<'a> MdnsRunner<'a> { Self(mdns) } - pub async fn run_udp(&mut self) -> Result<(), Error> { - let mut tx_buf = MdnsTxBuf::uninit(); - let mut rx_buf = MdnsRxBuf::uninit(); + #[cfg(any(feature = "std", feature = "embassy-net"))] + pub async fn run_udp( + &mut self, + stack: &crate::transport::network::NetworkStack, + buffers: &mut crate::transport::udp::UdpBuffers, + ) -> Result<(), Error> + where + D: crate::transport::network::NetworkStackMulticastDriver + + crate::transport::network::NetworkStackDriver + + 'static, + { + let mut tx_buf = + core::mem::MaybeUninit::<[u8; crate::transport::packet::MAX_TX_BUF_SIZE]>::uninit(); + let mut rx_buf = + core::mem::MaybeUninit::<[u8; crate::transport::packet::MAX_RX_BUF_SIZE]>::uninit(); let tx_buf = &mut tx_buf; let rx_buf = &mut rx_buf; @@ -121,10 +127,18 @@ impl<'a> MdnsRunner<'a> { let tx_pipe = &tx_pipe; let rx_pipe = &rx_pipe; - let mut udp = UdpListener::new(SocketAddr::new(IP_BIND_ADDR, PORT)).await?; + let mut udp = crate::transport::udp::UdpListener::new( + stack, + crate::transport::network::SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), PORT), + buffers, + ) + .await?; udp.join_multicast_v6(IPV6_BROADCAST_ADDR, self.0.interface)?; - udp.join_multicast_v4(IP_BROADCAST_ADDR, Ipv4Addr::from(self.0.host.ip))?; + udp.join_multicast_v4( + IP_BROADCAST_ADDR, + crate::transport::network::Ipv4Addr::from(self.0.host.ip), + )?; let udp = &udp; @@ -168,7 +182,9 @@ impl<'a> MdnsRunner<'a> { let mut run = pin!(async move { self.run(tx_pipe, rx_pipe).await }); - select3(&mut tx, &mut rx, &mut run).await.unwrap() + embassy_futures::select::select3(&mut tx, &mut rx, &mut run) + .await + .unwrap() } pub async fn run(&self, tx_pipe: &Pipe<'_>, rx_pipe: &Pipe<'_>) -> Result<(), Error> { diff --git a/matter/src/transport/network.rs b/matter/src/transport/network.rs index ba50386d..c3b71ee8 100644 --- a/matter/src/transport/network.rs +++ b/matter/src/transport/network.rs @@ -55,3 +55,35 @@ impl Debug for Address { } } } + +#[cfg(all(feature = "std", not(feature = "embassy-net")))] +pub use std_stack::*; + +#[cfg(feature = "embassy-net")] +pub use embassy_stack::*; + +#[cfg(all(feature = "std", not(feature = "embassy-net")))] +mod std_stack { + pub trait NetworkStackDriver {} + + impl NetworkStackDriver for () {} + + pub trait NetworkStackMulticastDriver {} + + impl NetworkStackMulticastDriver for () {} + + pub struct NetworkStack(D); + + impl NetworkStack<()> { + pub const fn new() -> Self { + Self(()) + } + } +} + +#[cfg(feature = "embassy-net")] +mod embassy_stack { + pub use embassy_net::Stack as NetworkStack; + pub use embassy_net_driver::Driver as NetworkStackDriver; + pub use smoltcp::phy::Device as NetworkStackMulticastDriver; +} diff --git a/matter/src/transport/runner.rs b/matter/src/transport/runner.rs index c4cd4ed7..ccb1034a 100644 --- a/matter/src/transport/runner.rs +++ b/matter/src/transport/runner.rs @@ -21,10 +21,9 @@ use crate::{ alloc, data_model::{core::DataModel, objects::DataModelHandler}, interaction_model::core::PROTO_ID_INTERACTION_MODEL, - transport::network::{Address, IpAddr, Ipv6Addr, SocketAddr}, CommissioningData, Matter, }; -use embassy_futures::select::{select, select3, select_slice, Either}; +use embassy_futures::select::{select, select_slice, Either}; use embassy_sync::{blocking_mutex::raw::NoopRawMutex, channel::Channel}; use log::{error, info, warn}; @@ -40,7 +39,6 @@ use super::{ exchange::{ExchangeCtr, Notification, MAX_EXCHANGES}, packet::{MAX_RX_STATUS_BUF_SIZE, MAX_TX_BUF_SIZE}, pipe::{Chunk, Pipe}, - udp::UdpListener, }; pub type TxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; @@ -103,20 +101,30 @@ impl<'a> TransportRunner<'a> { &self.transport } - pub async fn run_udp( + #[cfg(any(feature = "std", feature = "embassy-net"))] + pub async fn run_udp( &mut self, + stack: &crate::transport::network::NetworkStack, + buffers: &mut crate::transport::udp::UdpBuffers, tx_buf: &mut TxBuf, rx_buf: &mut RxBuf, dev_comm: CommissioningData, handler: &H, ) -> Result<(), Error> where + D: crate::transport::network::NetworkStackDriver, H: DataModelHandler, { - let udp = UdpListener::new(SocketAddr::new( - IpAddr::V6(Ipv6Addr::UNSPECIFIED), - self.transport.matter().port, - )) + let udp = crate::transport::udp::UdpListener::new( + stack, + crate::transport::network::SocketAddr::new( + crate::transport::network::IpAddr::V6( + crate::transport::network::Ipv6Addr::UNSPECIFIED, + ), + self.transport.matter().port, + ), + buffers, + ) .await?; let tx_pipe = Pipe::new(unsafe { tx_buf.assume_init_mut() }); @@ -154,7 +162,7 @@ impl<'a> TransportRunner<'a> { data.chunk = Some(Chunk { start: 0, end: len, - addr: Address::Udp(addr), + addr: crate::transport::network::Address::Udp(addr), }); rx_pipe.data_supplied_notification.signal(()); } @@ -166,7 +174,9 @@ impl<'a> TransportRunner<'a> { let mut run = pin!(async move { self.run(tx_pipe, rx_pipe, dev_comm, handler).await }); - select3(&mut tx, &mut rx, &mut run).await.unwrap() + embassy_futures::select::select3(&mut tx, &mut rx, &mut run) + .await + .unwrap() } pub async fn run( diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index 9b234894..e9e5811d 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -15,29 +15,45 @@ * limitations under the License. */ -#[cfg(feature = "std")] +#[cfg(all(feature = "std", not(feature = "embassy-net")))] pub use smol_udp::*; -#[cfg(not(feature = "std"))] -pub use dummy_udp::*; +#[cfg(feature = "embassy-net")] +pub use embassy_udp::*; -#[cfg(feature = "std")] +#[cfg(all(feature = "std", not(feature = "embassy-net")))] mod smol_udp { use crate::error::*; use log::{debug, info, warn}; use smol::net::UdpSocket; - use crate::transport::network::{Ipv4Addr, Ipv6Addr, SocketAddr}; + use crate::transport::network::{ + Ipv4Addr, Ipv6Addr, NetworkStack, NetworkStackDriver, NetworkStackMulticastDriver, + SocketAddr, + }; - pub struct UdpListener { - socket: UdpSocket, + pub struct UdpBuffers(()); + + impl UdpBuffers { + pub const fn new() -> Self { + Self(()) + } } - impl UdpListener { - pub async fn new(addr: SocketAddr) -> Result { - let listener = UdpListener { - socket: UdpSocket::bind((addr.ip(), addr.port())).await?, - }; + pub struct UdpListener<'a, D>(UdpSocket, &'a NetworkStack) + where + D: NetworkStackDriver; + + impl<'a, D> UdpListener<'a, D> + where + D: NetworkStackDriver + 'a, + { + pub async fn new( + stack: &'a NetworkStack, + addr: SocketAddr, + _buffers: &'a mut UdpBuffers, + ) -> Result, Error> { + let listener = UdpListener(UdpSocket::bind((addr.ip(), addr.port())).await?, stack); info!("Listening on {:?}", addr); @@ -48,8 +64,11 @@ mod smol_udp { &mut self, multiaddr: Ipv6Addr, interface: u32, - ) -> Result<(), Error> { - self.socket.join_multicast_v6(&multiaddr, interface)?; + ) -> Result<(), Error> + where + D: NetworkStackMulticastDriver + 'static, + { + self.0.join_multicast_v6(&multiaddr, interface)?; info!("Joined IPV6 multicast {}/{}", multiaddr, interface); @@ -60,9 +79,12 @@ mod smol_udp { &mut self, multiaddr: Ipv4Addr, interface: Ipv4Addr, - ) -> Result<(), Error> { + ) -> Result<(), Error> + where + D: NetworkStackMulticastDriver + 'static, + { #[cfg(not(target_os = "espidf"))] - self.socket.join_multicast_v4(multiaddr, interface)?; + self.0.join_multicast_v4(multiaddr, interface)?; // join_multicast_v4() is broken for ESP-IDF, most likely due to wrong `ip_mreq` signature in the `libc` crate // Note that also most *_multicast_v4 and *_multicast_v6 methods are broken as well in Rust STD for the ESP-IDF @@ -101,7 +123,7 @@ mod smol_udp { }; esp_setsockopt( - &mut self.socket, + &mut self.0, esp_idf_sys::IPPROTO_IP, esp_idf_sys::IP_ADD_MEMBERSHIP, mreq, @@ -114,18 +136,18 @@ mod smol_udp { } pub async fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, SocketAddr), Error> { - let (size, addr) = self.socket.recv_from(in_buf).await.map_err(|e| { + let (len, addr) = self.0.recv_from(in_buf).await.map_err(|e| { warn!("Error on the network: {:?}", e); ErrorCode::Network })?; - debug!("Got packet {:?} from addr {:?}", &in_buf[..size], addr); + debug!("Got packet {:?} from addr {:?}", &in_buf[..len], addr); - Ok((size, addr)) + Ok((len, addr)) } pub async fn send(&self, addr: SocketAddr, out_buf: &[u8]) -> Result { - let len = self.socket.send_to(out_buf, addr).await.map_err(|e| { + let len = self.0.send_to(out_buf, addr).await.map_err(|e| { warn!("Error on the network: {:?}", e); ErrorCode::Network })?; @@ -143,35 +165,92 @@ mod smol_udp { } } -#[cfg(not(feature = "std"))] -mod dummy_udp { - use core::future::pending; +#[cfg(feature = "embassy-net")] +mod embassy_udp { + use core::mem::MaybeUninit; + + use embassy_net::udp::{PacketMetadata, UdpSocket}; + + use smoltcp::wire::{IpAddress, IpEndpoint, Ipv4Address, Ipv6Address}; use crate::error::*; - use log::{debug, info}; - use crate::transport::network::{Ipv4Addr, Ipv6Addr, SocketAddr}; + use log::{debug, info, warn}; - pub struct UdpListener {} + use crate::transport::network::{ + IpAddr, Ipv4Addr, Ipv6Addr, NetworkStack, NetworkStackDriver, NetworkStackMulticastDriver, + SocketAddr, + }; - impl UdpListener { - pub async fn new(addr: SocketAddr) -> Result { - let listener = UdpListener {}; + const RX_BUF_SIZE: usize = 4096; + const TX_BUF_SIZE: usize = 4096; - info!("Pretending to listen on {:?}", addr); + pub struct UdpBuffers { + rx_buffer: MaybeUninit<[u8; RX_BUF_SIZE]>, + tx_buffer: MaybeUninit<[u8; TX_BUF_SIZE]>, + rx_meta: [PacketMetadata; 16], + tx_meta: [PacketMetadata; 16], + } - Ok(listener) + impl UdpBuffers { + pub const fn new() -> Self { + Self { + rx_buffer: MaybeUninit::uninit(), + tx_buffer: MaybeUninit::uninit(), + + rx_meta: [PacketMetadata::EMPTY; 16], + tx_meta: [PacketMetadata::EMPTY; 16], + } + } + } + + pub struct UdpListener<'a, D>(UdpSocket<'a>, &'a NetworkStack) + where + D: NetworkStackDriver; + + impl<'a, D> UdpListener<'a, D> + where + D: NetworkStackDriver + 'a, + { + pub async fn new( + stack: &'a NetworkStack, + addr: SocketAddr, + buffers: &'a mut UdpBuffers, + ) -> Result, Error> { + let mut socket = UdpSocket::new( + stack, + &mut buffers.rx_meta, + unsafe { buffers.rx_buffer.assume_init_mut() }, + &mut buffers.tx_meta, + unsafe { buffers.tx_buffer.assume_init_mut() }, + ); + + socket.bind(addr.port()).map_err(|e| { + warn!("Error on the network: {:?}", e); + ErrorCode::Network + })?; + + info!("Listening on {:?}", addr); + + Ok(UdpListener(socket, stack)) } pub fn join_multicast_v6( &mut self, multiaddr: Ipv6Addr, - interface: u32, - ) -> Result<(), Error> { - info!( - "Pretending to join IPV6 multicast {}/{}", - multiaddr, interface - ); + _interface: u32, + ) -> Result<(), Error> + where + D: NetworkStackMulticastDriver + 'static, + { + self.1 + .join_multicast_group(Self::from_ip_addr(IpAddr::V6(multiaddr))) + .map_err(|e| { + warn!("Error on the network: {:?}", e); + ErrorCode::Network + })?; + + info!("Joined IP multicast {}", multiaddr); Ok(()) } @@ -179,23 +258,45 @@ mod dummy_udp { pub fn join_multicast_v4( &mut self, multiaddr: Ipv4Addr, - interface: Ipv4Addr, - ) -> Result<(), Error> { - info!( - "Pretending to join IP multicast {}/{}", - multiaddr, interface - ); + _interface: Ipv4Addr, + ) -> Result<(), Error> + where + D: NetworkStackMulticastDriver + 'static, + { + self.1 + .join_multicast_group(Self::from_ip_addr(IpAddr::V4(multiaddr))) + .map_err(|e| { + warn!("Error on the network: {:?}", e); + ErrorCode::Network + })?; + + info!("Joined IP multicast {}", multiaddr); Ok(()) } - pub async fn recv(&self, _in_buf: &mut [u8]) -> Result<(usize, SocketAddr), Error> { - info!("Pretending to wait for incoming packets (looping forever)"); + pub async fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, SocketAddr), Error> { + let (len, ep) = self.0.recv_from(in_buf).await.map_err(|e| { + warn!("Error on the network: {:?}", e); + ErrorCode::Network + })?; + + let addr = Self::to_socket_addr(ep); - pending().await + debug!("Got packet {:?} from addr {:?}", &in_buf[..len], addr); + + Ok((len, addr)) } pub async fn send(&self, addr: SocketAddr, out_buf: &[u8]) -> Result { + self.0 + .send_to(out_buf, Self::from_socket_addr(addr)) + .await + .map_err(|e| { + warn!("Error on the network: {:?}", e); + ErrorCode::Network + })?; + debug!( "Send packet {:?} ({}/{}) to addr {:?}", out_buf, @@ -206,5 +307,27 @@ mod dummy_udp { Ok(out_buf.len()) } + + fn to_socket_addr(ep: IpEndpoint) -> SocketAddr { + SocketAddr::new(Self::to_ip_addr(ep.addr), ep.port) + } + + fn from_socket_addr(addr: SocketAddr) -> IpEndpoint { + IpEndpoint::new(Self::from_ip_addr(addr.ip()), addr.port()) + } + + fn to_ip_addr(ip: IpAddress) -> IpAddr { + match ip { + IpAddress::Ipv4(addr) => IpAddr::V4(Ipv4Addr::from(addr.0)), + IpAddress::Ipv6(addr) => IpAddr::V6(Ipv6Addr::from(addr.0)), + } + } + + fn from_ip_addr(ip: IpAddr) -> IpAddress { + match ip { + IpAddr::V4(v4) => IpAddress::Ipv4(Ipv4Address::from_bytes(&v4.octets())), + IpAddr::V6(v6) => IpAddress::Ipv6(Ipv6Address::from_bytes(&v6.octets())), + } + } } } From 0d73ba74ee67a2378873ddc685511ee22c46d8c6 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Fri, 14 Jul 2023 22:26:01 +0000 Subject: [PATCH 66/72] UDP stack based on embassy-net --- Cargo.toml | 1 - examples/onoff_light/src/main.rs | 119 ++++++++++--------------------- matter/Cargo.toml | 5 +- matter/src/mdns.rs | 14 ++-- matter/src/mdns/astro.rs | 44 ++++++++---- matter/src/mdns/builtin.rs | 74 ++++++++++--------- matter/src/transport/network.rs | 8 +-- matter/src/transport/runner.rs | 75 ++++++++++++++++--- matter/src/transport/udp.rs | 36 +++++----- 9 files changed, 209 insertions(+), 167 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 11f05af5..0ec37cec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,6 @@ exclude = ["examples/*"] # For compatibility with ESP IDF [patch.crates-io] -smol = { git = "https://github.com/esp-rs-compat/smol" } polling = { git = "https://github.com/esp-rs-compat/polling" } socket2 = { git = "https://github.com/esp-rs-compat/socket2" } diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 93eab5f7..1b882ac9 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -18,7 +18,6 @@ use core::borrow::Borrow; use core::pin::pin; -use embassy_futures::select::select3; use log::info; use matter::core::{CommissioningData, Matter}; use matter::data_model::cluster_basic_information::BasicInfoConfig; @@ -28,13 +27,11 @@ use matter::data_model::objects::*; use matter::data_model::root_endpoint; use matter::data_model::system_model::descriptor; use matter::error::Error; -use matter::mdns::{DefaultMdns, DefaultMdnsRunner}; +use matter::mdns::MdnsService; use matter::persist::FilePsm; use matter::secure_channel::spake2p::VerifierData; use matter::transport::network::{Ipv4Addr, Ipv6Addr, NetworkStack}; -use matter::transport::runner::{RxBuf, TransportRunner, TxBuf}; -use matter::transport::udp::UdpBuffers; -use matter::utils::select::EitherUnwrap; +use matter::transport::runner::{AllUdpBuffers, TransportRunner}; mod dev_att; @@ -59,10 +56,11 @@ fn run() -> Result<(), Error> { initialize_logger(); info!( - "Matter memory: mDNS={}, Matter={}, TransportRunner={}", - core::mem::size_of::(), + "Matter memory: mDNS={}, Matter={}, TransportRunner={}, UdpBuffers={}", + core::mem::size_of::(), core::mem::size_of::(), core::mem::size_of::(), + core::mem::size_of::(), ); let dev_det = BasicInfoConfig { @@ -83,20 +81,6 @@ fn run() -> Result<(), Error> { let (ipv4_addr, ipv6_addr, interface) = initialize_network()?; - let mdns = DefaultMdns::new( - 0, - "matter-demo", - ipv4_addr.octets(), - Some(ipv6_addr.octets()), - interface, - &dev_det, - matter::MATTER_PORT, - ); - - let mut mdns_runner = DefaultMdnsRunner::new(&mdns); - - info!("mDNS initialized: {:p}, {:p}", &mdns, &mdns_runner); - let dev_att = dev_att::HardCodedDevAtt::new(); #[cfg(feature = "std")] @@ -113,6 +97,18 @@ fn run() -> Result<(), Error> { #[cfg(not(feature = "std"))] let rand = matter::utils::rand::dummy_rand; + let mdns = MdnsService::new( + 0, + "matter-demo", + ipv4_addr.octets(), + Some(ipv6_addr.octets()), + interface, + &dev_det, + matter::MATTER_PORT, + ); + + info!("mDNS initialized: {:p}", &mdns); + let matter = Matter::new( // vid/pid should match those in the DAC &dev_det, @@ -125,20 +121,6 @@ fn run() -> Result<(), Error> { info!("Matter initialized: {:p}", &matter); - let mut runner = TransportRunner::new(&matter); - - info!("Transport Runner initialized: {:p}", &runner); - - let mut tx_buf = TxBuf::uninit(); - let mut rx_buf = RxBuf::uninit(); - - // NOTE (no_std): If using the `embassy-net` UDP implementation, replace this dummy stack with the `embassy-net` one - // When using a custom UDP stack, remove this - let stack = NetworkStack::new(); - - let mut mdns_udp_buffers = UdpBuffers::new(); - let mut trans_udp_buffers = UdpBuffers::new(); - #[cfg(all(feature = "std", not(target_os = "espidf")))] { let mut buf = [0; 4096]; @@ -152,62 +134,33 @@ fn run() -> Result<(), Error> { } } - let node = Node { - id: 0, - endpoints: &[ - root_endpoint::endpoint(0), - Endpoint { - id: 1, - device_type: DEV_TYPE_ON_OFF_LIGHT, - clusters: &[descriptor::CLUSTER, cluster_on_off::CLUSTER], - }, - ], - }; + let mut runner = TransportRunner::new(&matter); + + info!("Transport runner initialized: {:p}", &runner); let handler = HandlerCompat(handler(&matter)); - let matter = &matter; - let node = &node; - let handler = &handler; - let runner = &mut runner; - let tx_buf = &mut tx_buf; - let rx_buf = &mut rx_buf; - let stack = &stack; - let mdns_udp_buffers = &mut mdns_udp_buffers; - let trans_udp_buffers = &mut trans_udp_buffers; + // NOTE (no_std): If using the `embassy-net` UDP implementation, replace this dummy stack with the `embassy-net` one + // When using a custom UDP stack, remove this + let stack = NetworkStack::new(); - info!( - "About to run wth node {:p}, handler {:p}, transport runner {:p}, mdns_runner {:p}", - node, handler, runner, &mdns_runner - ); + let mut buffers = AllUdpBuffers::new(); - let mut fut = pin!(async move { - // NOTE: If using a custom UDP stack, replace `run_udp` with `run` - // and connect the pipes of the `run` method with the custom UDP stack - let mut transport = pin!(runner.run_udp( - stack, - trans_udp_buffers, - tx_buf, - rx_buf, - CommissioningData { - // TODO: Hard-coded for now - verifier: VerifierData::new_with_pw(123456, *matter.borrow()), - discriminator: 250, - }, - &handler, - )); - - // NOTE: If using a custom UDP stack, replace `run_udp` with `run` - // and connect the pipes of the `run` method with the custom UDP stack - let mut mdns = pin!(mdns_runner.run_udp(stack, mdns_udp_buffers)); - - let mut save = pin!(save(matter, &psm)); - select3(&mut transport, &mut mdns, &mut save).await.unwrap() - }); + let mut fut = pin!(runner.run_udp_all( + &stack, + &mdns, + &mut buffers, + CommissioningData { + // TODO: Hard-coded for now + verifier: VerifierData::new_with_pw(123456, *matter.borrow()), + discriminator: 250, + }, + &handler, + )); // NOTE: For no_std, replace with your own no_std way of polling the future #[cfg(feature = "std")] - smol::block_on(&mut fut)?; + async_io::block_on(&mut fut)?; // NOTE (no_std): For no_std, replace with your own more efficient no_std executor, // because the executor used below is a simple busy-loop poller diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 87533b5d..46e16c99 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -17,7 +17,7 @@ path = "src/lib.rs" [features] default = ["os", "crypto_rustcrypto"] os = ["std", "backtrace", "env_logger", "nix", "critical-section/std", "embassy-sync/std", "embassy-time/std"] -std = ["alloc", "rand", "qrcode", "async-io", "smol", "esp-idf-sys/std"] +std = ["alloc", "rand", "qrcode", "async-io", "esp-idf-sys/std"] backtrace = [] alloc = [] nightly = [] @@ -47,6 +47,8 @@ embassy-time = { version = "0.1.1", features = ["generic-queue-8"] } embassy-sync = "0.2" critical-section = "1.1.1" domain = { version = "0.7.2", default_features = false, features = ["heapless"] } + +# embassy-net dependencies embassy-net = { version = "0.1", features = ["udp", "igmp", "proto-ipv6", "medium-ethernet", "medium-ip"], optional = true } embassy-net-driver = { version = "0.1", optional = true } smoltcp = { version = "0.10", default-features = false, optional = true } @@ -54,7 +56,6 @@ smoltcp = { version = "0.10", default-features = false, optional = true } # STD-only dependencies rand = { version = "0.8.5", optional = true } qrcode = { version = "0.12", default-features = false, optional = true } # Print QR code -smol = { version = "1.2", optional = true } # =1.2 for compatibility with ESP IDF async-io = { version = "=1.12", optional = true } # =1.2 for compatibility with ESP IDF # crypto diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index c2ae5385..47a08614 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -56,16 +56,18 @@ where } #[cfg(all(feature = "std", target_os = "macos"))] -pub type DefaultMdns<'a> = astro::Mdns<'a>; - +pub use astro::MdnsRunner; +#[cfg(all(feature = "std", target_os = "macos"))] +pub use astro::MdnsService; #[cfg(all(feature = "std", target_os = "macos"))] -pub type DefaultMdnsRunner<'a> = astro::MdnsRunner<'a>; +pub use astro::MdnsUdpBuffers; #[cfg(not(all(feature = "std", target_os = "macos")))] -pub type DefaultMdns<'a> = builtin::Mdns<'a>; - +pub use builtin::MdnsRunner; +#[cfg(not(all(feature = "std", target_os = "macos")))] +pub use builtin::MdnsService; #[cfg(not(all(feature = "std", target_os = "macos")))] -pub type DefaultMdnsRunner<'a> = builtin::MdnsRunner<'a>; +pub use builtin::MdnsUdpBuffers; pub struct DummyMdns; diff --git a/matter/src/mdns/astro.rs b/matter/src/mdns/astro.rs index e7ae4c20..857fb467 100644 --- a/matter/src/mdns/astro.rs +++ b/matter/src/mdns/astro.rs @@ -11,13 +11,14 @@ use log::info; use super::ServiceMode; -pub struct Mdns<'a> { +pub struct MdnsService<'a> { dev_det: &'a BasicInfoConfig<'a>, matter_port: u16, services: RefCell>, } -impl<'a> Mdns<'a> { +impl<'a> MdnsService<'a> { + /// This constructor takes extra parameters for API-compatibility with builtin::MdnsRunner pub fn new( _id: u16, _hostname: &str, @@ -80,28 +81,41 @@ impl<'a> Mdns<'a> { } } -pub struct MdnsRunner<'a>(&'a Mdns<'a>); +/// Only for API-compatibility with builtin::MdnsRunner +pub struct MdnsUdpBuffers(()); -impl<'a> MdnsRunner<'a> { - pub const fn new(mdns: &'a Mdns<'a>) -> Self { - Self(mdns) +/// Only for API-compatibility with builtin::MdnsRunner +impl MdnsUdpBuffers { + #[inline(always)] + pub const fn new() -> Self { + Self(()) } +} - pub async fn run_udp(&mut self) -> Result<(), Error> { - core::future::pending::>().await +impl<'a> super::Mdns for MdnsService<'a> { + fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> { + MdnsService::add(self, service, mode) } - pub async fn run(&self, _tx_pipe: &Pipe<'_>, _rx_pipe: &Pipe<'_>) -> Result<(), Error> { - core::future::pending::>().await + fn remove(&self, service: &str) -> Result<(), Error> { + MdnsService::remove(self, service) } } -impl<'a> super::Mdns for Mdns<'a> { - fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> { - Mdns::add(self, service, mode) +/// Only for API-compatibility with builtin::MdnsRunner +pub struct MdnsRunner<'a>(&'a MdnsService<'a>); + +/// Only for API-compatibility with builtin::MdnsRunner +impl<'a> MdnsRunner<'a> { + pub const fn new(mdns: &'a MdnsService<'a>) -> Self { + Self(mdns) } - fn remove(&self, service: &str) -> Result<(), Error> { - Mdns::remove(self, service) + pub async fn run_udp(&mut self, buffers: &mut MdnsUdpBuffers) -> Result<(), Error> { + core::future::pending::>().await + } + + pub async fn run(&self, _tx_pipe: &Pipe<'_>, _rx_pipe: &Pipe<'_>) -> Result<(), Error> { + core::future::pending::>().await } } diff --git a/matter/src/mdns/builtin.rs b/matter/src/mdns/builtin.rs index c869218d..f4a5c6e6 100644 --- a/matter/src/mdns/builtin.rs +++ b/matter/src/mdns/builtin.rs @@ -22,7 +22,7 @@ const IPV6_BROADCAST_ADDR: Ipv6Addr = Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x const PORT: u16 = 5353; -pub struct Mdns<'a> { +pub struct MdnsService<'a> { host: Host<'a>, #[allow(unused)] interface: u32, @@ -32,7 +32,7 @@ pub struct Mdns<'a> { notification: Notification, } -impl<'a> Mdns<'a> { +impl<'a> MdnsService<'a> { #[inline(always)] pub const fn new( id: u16, @@ -95,10 +95,29 @@ impl<'a> Mdns<'a> { } } -pub struct MdnsRunner<'a>(&'a Mdns<'a>); +#[cfg(any(feature = "std", feature = "embassy-net"))] +pub struct MdnsUdpBuffers { + udp: crate::transport::udp::UdpBuffers, + tx_buf: core::mem::MaybeUninit<[u8; crate::transport::packet::MAX_TX_BUF_SIZE]>, + rx_buf: core::mem::MaybeUninit<[u8; crate::transport::packet::MAX_RX_BUF_SIZE]>, +} + +#[cfg(any(feature = "std", feature = "embassy-net"))] +impl MdnsUdpBuffers { + #[inline(always)] + pub const fn new() -> Self { + Self { + udp: crate::transport::udp::UdpBuffers::new(), + tx_buf: core::mem::MaybeUninit::uninit(), + rx_buf: core::mem::MaybeUninit::uninit(), + } + } +} + +pub struct MdnsRunner<'a>(&'a MdnsService<'a>); impl<'a> MdnsRunner<'a> { - pub const fn new(mdns: &'a Mdns<'a>) -> Self { + pub const fn new(mdns: &'a MdnsService<'a>) -> Self { Self(mdns) } @@ -106,31 +125,17 @@ impl<'a> MdnsRunner<'a> { pub async fn run_udp( &mut self, stack: &crate::transport::network::NetworkStack, - buffers: &mut crate::transport::udp::UdpBuffers, + buffers: &mut MdnsUdpBuffers, ) -> Result<(), Error> where D: crate::transport::network::NetworkStackMulticastDriver + crate::transport::network::NetworkStackDriver + 'static, { - let mut tx_buf = - core::mem::MaybeUninit::<[u8; crate::transport::packet::MAX_TX_BUF_SIZE]>::uninit(); - let mut rx_buf = - core::mem::MaybeUninit::<[u8; crate::transport::packet::MAX_RX_BUF_SIZE]>::uninit(); - - let tx_buf = &mut tx_buf; - let rx_buf = &mut rx_buf; - - let tx_pipe = Pipe::new(unsafe { tx_buf.assume_init_mut() }); - let rx_pipe = Pipe::new(unsafe { rx_buf.assume_init_mut() }); - - let tx_pipe = &tx_pipe; - let rx_pipe = &rx_pipe; - let mut udp = crate::transport::udp::UdpListener::new( stack, crate::transport::network::SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), PORT), - buffers, + &mut buffers.udp, ) .await?; @@ -140,6 +145,11 @@ impl<'a> MdnsRunner<'a> { crate::transport::network::Ipv4Addr::from(self.0.host.ip), )?; + let tx_pipe = Pipe::new(unsafe { buffers.tx_buf.assume_init_mut() }); + let rx_pipe = Pipe::new(unsafe { buffers.rx_buf.assume_init_mut() }); + + let tx_pipe = &tx_pipe; + let rx_pipe = &rx_pipe; let udp = &udp; let mut tx = pin!(async move { @@ -295,24 +305,24 @@ impl<'a> MdnsRunner<'a> { } } -impl<'a> super::Mdns for Mdns<'a> { - fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> { - Mdns::add(self, service, mode) - } - - fn remove(&self, service: &str) -> Result<(), Error> { - Mdns::remove(self, service) - } -} - -impl<'a> Services for Mdns<'a> { +impl<'a> Services for MdnsService<'a> { type Error = crate::error::Error; fn for_each(&self, callback: F) -> Result<(), Error> where F: FnMut(&Service) -> Result<(), Error>, { - Mdns::for_each(self, callback) + MdnsService::for_each(self, callback) + } +} + +impl<'a> super::Mdns for MdnsService<'a> { + fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> { + MdnsService::add(self, service, mode) + } + + fn remove(&self, service: &str) -> Result<(), Error> { + MdnsService::remove(self, service) } } diff --git a/matter/src/transport/network.rs b/matter/src/transport/network.rs index c3b71ee8..21cfd726 100644 --- a/matter/src/transport/network.rs +++ b/matter/src/transport/network.rs @@ -60,10 +60,10 @@ impl Debug for Address { pub use std_stack::*; #[cfg(feature = "embassy-net")] -pub use embassy_stack::*; +pub use embassy_net_stack::*; -#[cfg(all(feature = "std", not(feature = "embassy-net")))] -mod std_stack { +#[cfg(feature = "std")] +pub mod std_stack { pub trait NetworkStackDriver {} impl NetworkStackDriver for () {} @@ -82,7 +82,7 @@ mod std_stack { } #[cfg(feature = "embassy-net")] -mod embassy_stack { +pub mod embassy_net_stack { pub use embassy_net::Stack as NetworkStack; pub use embassy_net_driver::Driver as NetworkStackDriver; pub use smoltcp::phy::Device as NetworkStackMulticastDriver; diff --git a/matter/src/transport/runner.rs b/matter/src/transport/runner.rs index ccb1034a..373021d4 100644 --- a/matter/src/transport/runner.rs +++ b/matter/src/transport/runner.rs @@ -41,8 +41,8 @@ use super::{ pipe::{Chunk, Pipe}, }; -pub type TxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; -pub type RxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>; +type TxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; +type RxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>; type SxBuf = MaybeUninit<[u8; MAX_RX_STATUS_BUF_SIZE]>; struct PacketPools { @@ -70,6 +70,42 @@ impl PacketPools { } } +#[cfg(any(feature = "std", feature = "embassy-net"))] +pub struct AllUdpBuffers { + transport: TransportUdpBuffers, + mdns: crate::mdns::MdnsUdpBuffers, +} + +#[cfg(any(feature = "std", feature = "embassy-net"))] +impl AllUdpBuffers { + #[inline(always)] + pub const fn new() -> Self { + Self { + transport: TransportUdpBuffers::new(), + mdns: crate::mdns::MdnsUdpBuffers::new(), + } + } +} + +#[cfg(any(feature = "std", feature = "embassy-net"))] +pub struct TransportUdpBuffers { + udp: crate::transport::udp::UdpBuffers, + tx_buf: TxBuf, + rx_buf: RxBuf, +} + +#[cfg(any(feature = "std", feature = "embassy-net"))] +impl TransportUdpBuffers { + #[inline(always)] + pub const fn new() -> Self { + Self { + udp: crate::transport::udp::UdpBuffers::new(), + tx_buf: core::mem::MaybeUninit::uninit(), + rx_buf: core::mem::MaybeUninit::uninit(), + } + } +} + /// This struct implements an executor-agnostic option to run the Matter transport stack end-to-end. /// /// Since it is not possible to use executor tasks spawning in an executor-agnostic way (yet), @@ -101,13 +137,36 @@ impl<'a> TransportRunner<'a> { &self.transport } + #[cfg(any(feature = "std", feature = "embassy-net"))] + pub async fn run_udp_all( + &mut self, + stack: &crate::transport::network::NetworkStack, + mdns: &crate::mdns::MdnsService<'_>, + buffers: &mut AllUdpBuffers, + dev_comm: CommissioningData, + handler: &H, + ) -> Result<(), Error> + where + D: crate::transport::network::NetworkStackDriver + + crate::transport::network::NetworkStackMulticastDriver + + 'static, + H: DataModelHandler, + { + let mut mdns_runner = crate::mdns::MdnsRunner::new(mdns); + + let mut mdns = pin!(mdns_runner.run_udp(stack, &mut buffers.mdns)); + let mut transport = pin!(self.run_udp(stack, &mut buffers.transport, dev_comm, handler)); + + embassy_futures::select::select(&mut transport, &mut mdns) + .await + .unwrap() + } + #[cfg(any(feature = "std", feature = "embassy-net"))] pub async fn run_udp( &mut self, stack: &crate::transport::network::NetworkStack, - buffers: &mut crate::transport::udp::UdpBuffers, - tx_buf: &mut TxBuf, - rx_buf: &mut RxBuf, + buffers: &mut TransportUdpBuffers, dev_comm: CommissioningData, handler: &H, ) -> Result<(), Error> @@ -123,12 +182,12 @@ impl<'a> TransportRunner<'a> { ), self.transport.matter().port, ), - buffers, + &mut buffers.udp, ) .await?; - let tx_pipe = Pipe::new(unsafe { tx_buf.assume_init_mut() }); - let rx_pipe = Pipe::new(unsafe { rx_buf.assume_init_mut() }); + let tx_pipe = Pipe::new(unsafe { buffers.tx_buf.assume_init_mut() }); + let rx_pipe = Pipe::new(unsafe { buffers.rx_buf.assume_init_mut() }); let tx_pipe = &tx_pipe; let rx_pipe = &rx_pipe; diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index e9e5811d..a9e24095 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -16,21 +16,25 @@ */ #[cfg(all(feature = "std", not(feature = "embassy-net")))] -pub use smol_udp::*; +pub use async_io::*; #[cfg(feature = "embassy-net")] -pub use embassy_udp::*; +pub use embassy_net::*; -#[cfg(all(feature = "std", not(feature = "embassy-net")))] -mod smol_udp { +#[cfg(feature = "std")] +pub mod async_io { use crate::error::*; + + use std::net::UdpSocket; + + use async_io::Async; + use log::{debug, info, warn}; - use smol::net::UdpSocket; - use crate::transport::network::{ - Ipv4Addr, Ipv6Addr, NetworkStack, NetworkStackDriver, NetworkStackMulticastDriver, - SocketAddr, + use crate::transport::network::std_stack::{ + NetworkStack, NetworkStackDriver, NetworkStackMulticastDriver, }; + use crate::transport::network::{Ipv4Addr, Ipv6Addr, SocketAddr}; pub struct UdpBuffers(()); @@ -40,7 +44,7 @@ mod smol_udp { } } - pub struct UdpListener<'a, D>(UdpSocket, &'a NetworkStack) + pub struct UdpListener<'a, D>(Async, &'a NetworkStack) where D: NetworkStackDriver; @@ -53,7 +57,7 @@ mod smol_udp { addr: SocketAddr, _buffers: &'a mut UdpBuffers, ) -> Result, Error> { - let listener = UdpListener(UdpSocket::bind((addr.ip(), addr.port())).await?, stack); + let listener = UdpListener(Async::::bind(addr)?, stack); info!("Listening on {:?}", addr); @@ -68,7 +72,7 @@ mod smol_udp { where D: NetworkStackMulticastDriver + 'static, { - self.0.join_multicast_v6(&multiaddr, interface)?; + self.0.get_ref().join_multicast_v6(&multiaddr, interface)?; info!("Joined IPV6 multicast {}/{}", multiaddr, interface); @@ -84,7 +88,7 @@ mod smol_udp { D: NetworkStackMulticastDriver + 'static, { #[cfg(not(target_os = "espidf"))] - self.0.join_multicast_v4(multiaddr, interface)?; + self.0.get_ref().join_multicast_v4(&multiaddr, &interface)?; // join_multicast_v4() is broken for ESP-IDF, most likely due to wrong `ip_mreq` signature in the `libc` crate // Note that also most *_multicast_v4 and *_multicast_v6 methods are broken as well in Rust STD for the ESP-IDF @@ -166,7 +170,7 @@ mod smol_udp { } #[cfg(feature = "embassy-net")] -mod embassy_udp { +pub mod embassy_net { use core::mem::MaybeUninit; use embassy_net::udp::{PacketMetadata, UdpSocket}; @@ -177,10 +181,10 @@ mod embassy_udp { use log::{debug, info, warn}; - use crate::transport::network::{ - IpAddr, Ipv4Addr, Ipv6Addr, NetworkStack, NetworkStackDriver, NetworkStackMulticastDriver, - SocketAddr, + use crate::transport::network::embassy_net_stack::{ + NetworkStack, NetworkStackDriver, NetworkStackMulticastDriver, }; + use crate::transport::network::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; const RX_BUF_SIZE: usize = 4096; const TX_BUF_SIZE: usize = 4096; From 24cdf079a6e18261e99589cdb85a1093cd3b66c8 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 17 Jul 2023 16:53:14 +0000 Subject: [PATCH 67/72] New helper APIs in Transport --- matter/src/transport/core.rs | 177 ++++++++++++++++++++++++++++++- matter/src/transport/runner.rs | 184 ++++----------------------------- 2 files changed, 196 insertions(+), 165 deletions(-) diff --git a/matter/src/transport/core.rs b/matter/src/transport/core.rs index 98c2fbab..1a51b916 100644 --- a/matter/src/transport/core.rs +++ b/matter/src/transport/core.rs @@ -17,13 +17,23 @@ use core::{borrow::Borrow, cell::RefCell}; -use crate::{error::ErrorCode, secure_channel::common::OpCode, Matter}; use embassy_futures::select::select; +use embassy_sync::{blocking_mutex::raw::NoopRawMutex, channel::Channel}; use embassy_time::{Duration, Timer}; -use log::info; + +use log::{error, info, warn}; use crate::{ - error::Error, secure_channel::common::PROTO_ID_SECURE_CHANNEL, transport::packet::Packet, + alloc, + data_model::{core::DataModel, objects::DataModelHandler}, + error::{Error, ErrorCode}, + interaction_model::core::PROTO_ID_INTERACTION_MODEL, + secure_channel::{ + common::{OpCode, PROTO_ID_SECURE_CHANNEL}, + core::SecureChannel, + }, + transport::packet::Packet, + Matter, }; use super::{ @@ -32,6 +42,8 @@ use super::{ MAX_EXCHANGES, }, mrp::ReliableMessage, + packet::{MAX_RX_BUF_SIZE, MAX_RX_STATUS_BUF_SIZE, MAX_TX_BUF_SIZE}, + pipe::{Chunk, Pipe}, session::SessionMgr, }; @@ -83,6 +95,165 @@ impl<'a> Transport<'a> { unimplemented!() } + #[inline(always)] + pub async fn handle_tx(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> { + loop { + loop { + { + let mut data = tx_pipe.data.lock().await; + + if data.chunk.is_none() { + let mut tx = alloc!(Packet::new_tx(data.buf)); + + if self.pull_tx(&mut tx).await? { + data.chunk = Some(Chunk { + start: tx.get_writebuf()?.get_start(), + end: tx.get_writebuf()?.get_tail(), + addr: tx.peer, + }); + tx_pipe.data_supplied_notification.signal(()); + } else { + break; + } + } + } + + tx_pipe.data_consumed_notification.wait().await; + } + + self.wait_tx().await?; + } + } + + #[inline(always)] + pub async fn handle_rx_multiplex<'t, 'e, const N: usize>( + &'t self, + rx_pipe: &Pipe<'_>, + construction_notification: &'e Notification, + channel: &Channel, N>, + ) -> Result<(), Error> + where + 't: 'e, + { + loop { + info!("Transport: waiting for incoming packets"); + + { + let mut data = rx_pipe.data.lock().await; + + if let Some(chunk) = data.chunk { + let mut rx = alloc!(Packet::new_rx(&mut data.buf[chunk.start..chunk.end])); + rx.peer = chunk.addr; + + if let Some(exchange_ctr) = + self.process_rx(construction_notification, &mut rx)? + { + let exchange_id = exchange_ctr.id().clone(); + + info!("Transport: got new exchange: {:?}", exchange_id); + + channel.send(exchange_ctr).await; + info!("Transport: exchange sent"); + + self.wait_construction(construction_notification, &rx, &exchange_id) + .await?; + + info!("Transport: exchange started"); + } + + data.chunk = None; + rx_pipe.data_consumed_notification.signal(()); + } + } + + rx_pipe.data_supplied_notification.wait().await + } + + #[allow(unreachable_code)] + Ok::<_, Error>(()) + } + + #[inline(always)] + pub async fn exchange_handler( + &self, + tx_buf: &mut [u8; MAX_TX_BUF_SIZE], + rx_buf: &mut [u8; MAX_RX_BUF_SIZE], + sx_buf: &mut [u8; MAX_RX_STATUS_BUF_SIZE], + handler_id: impl core::fmt::Display, + channel: &Channel, N>, + handler: &H, + ) -> Result<(), Error> + where + H: DataModelHandler, + { + loop { + let exchange_ctr: ExchangeCtr<'_> = channel.recv().await; + + info!( + "Handler {}: Got exchange {:?}", + handler_id, + exchange_ctr.id() + ); + + let result = self + .handle_exchange(tx_buf, rx_buf, sx_buf, exchange_ctr, handler) + .await; + + if let Err(err) = result { + warn!( + "Handler {}: Exchange closed because of error: {:?}", + handler_id, err + ); + } else { + info!("Handler {}: Exchange completed", handler_id); + } + } + } + + #[inline(always)] + #[cfg_attr(feature = "nightly", allow(clippy::await_holding_refcell_ref))] // Fine because of the async mutex + pub async fn handle_exchange( + &self, + tx_buf: &mut [u8; MAX_TX_BUF_SIZE], + rx_buf: &mut [u8; MAX_RX_BUF_SIZE], + sx_buf: &mut [u8; MAX_RX_STATUS_BUF_SIZE], + exchange_ctr: ExchangeCtr<'_>, + handler: &H, + ) -> Result<(), Error> + where + H: DataModelHandler, + { + let mut tx = alloc!(Packet::new_tx(tx_buf.as_mut())); + let mut rx = alloc!(Packet::new_rx(rx_buf.as_mut())); + + let mut exchange = alloc!(exchange_ctr.get(&mut rx).await?); + + match rx.get_proto_id() { + PROTO_ID_SECURE_CHANNEL => { + let sc = SecureChannel::new(self.matter()); + + sc.handle(&mut exchange, &mut rx, &mut tx).await?; + + self.matter().notify_changed(); + } + PROTO_ID_INTERACTION_MODEL => { + let dm = DataModel::new(handler); + + let mut rx_status = alloc!(Packet::new_rx(sx_buf)); + + dm.handle(&mut exchange, &mut rx, &mut tx, &mut rx_status) + .await?; + + self.matter().notify_changed(); + } + other => { + error!("Unknown Proto-ID: {}", other); + } + } + + Ok(()) + } + pub fn process_rx<'r>( &'r self, construction_notification: &'r Notification, diff --git a/matter/src/transport/runner.rs b/matter/src/transport/runner.rs index 373021d4..d46b3e3a 100644 --- a/matter/src/transport/runner.rs +++ b/matter/src/transport/runner.rs @@ -17,26 +17,17 @@ use core::{mem::MaybeUninit, pin::pin}; -use crate::{ - alloc, - data_model::{core::DataModel, objects::DataModelHandler}, - interaction_model::core::PROTO_ID_INTERACTION_MODEL, - CommissioningData, Matter, -}; use embassy_futures::select::{select, select_slice, Either}; use embassy_sync::{blocking_mutex::raw::NoopRawMutex, channel::Channel}; -use log::{error, info, warn}; -use crate::{ - error::Error, - secure_channel::{common::PROTO_ID_SECURE_CHANNEL, core::SecureChannel}, - transport::packet::{Packet, MAX_RX_BUF_SIZE}, - utils::select::EitherUnwrap, -}; +use log::{error, info}; + +use crate::{data_model::objects::DataModelHandler, CommissioningData, Matter}; +use crate::{error::Error, transport::packet::MAX_RX_BUF_SIZE, utils::select::EitherUnwrap}; use super::{ core::Transport, - exchange::{ExchangeCtr, Notification, MAX_EXCHANGES}, + exchange::{Notification, MAX_EXCHANGES}, packet::{MAX_RX_STATUS_BUF_SIZE, MAX_TX_BUF_SIZE}, pipe::{Chunk, Pipe}, }; @@ -157,7 +148,7 @@ impl<'a> TransportRunner<'a> { let mut mdns = pin!(mdns_runner.run_udp(stack, &mut buffers.mdns)); let mut transport = pin!(self.run_udp(stack, &mut buffers.transport, dev_comm, handler)); - embassy_futures::select::select(&mut transport, &mut mdns) + embassy_futures::select::select(&mut mdns, &mut transport) .await .unwrap() } @@ -265,11 +256,12 @@ impl<'a> TransportRunner<'a> { &construction_notification, handler )); - let mut tx = pin!(Self::handle_tx(&self.transport, tx_pipe)); + let mut tx = pin!(self.transport.handle_tx(tx_pipe)); select(&mut rx, &mut tx).await.unwrap() } + #[inline(always)] async fn handle_rx( transport: &Transport<'_>, pools: &mut PacketPools, @@ -289,85 +281,30 @@ impl<'a> TransportRunner<'a> { info!("Handlers size: {}", core::mem::size_of_val(&handlers)); - let pools = &mut *pools as *mut _; + // Unsafely allow mutable aliasing in the packet pools by different indices + let pools: *mut PacketPools = pools; for index in 0..MAX_EXCHANGES { let channel = &channel; let handler_id = index; + let pools = unsafe { pools.as_mut() }.unwrap(); + + let tx_buf = unsafe { pools.tx[handler_id].assume_init_mut() }; + let rx_buf = unsafe { pools.rx[handler_id].assume_init_mut() }; + let sx_buf = unsafe { pools.sx[handler_id].assume_init_mut() }; + handlers - .push(async move { - loop { - let exchange_ctr: ExchangeCtr<'_> = channel.recv().await; - - info!( - "Handler {}: Got exchange {:?}", - handler_id, - exchange_ctr.id() - ); - - let result = Self::handle_exchange( - transport, - pools, - handler_id, - exchange_ctr, - handler, - ) - .await; - - if let Err(err) = result { - warn!( - "Handler {}: Exchange closed because of error: {:?}", - handler_id, err - ); - } else { - info!("Handler {}: Exchange completed", handler_id); - } - } - }) + .push( + transport + .exchange_handler(tx_buf, rx_buf, sx_buf, handler_id, channel, handler), + ) .map_err(|_| ()) .unwrap(); } - let mut rx = pin!(async { - loop { - info!("Transport: waiting for incoming packets"); - - { - let mut data = rx_pipe.data.lock().await; - - if let Some(chunk) = data.chunk { - let mut rx = alloc!(Packet::new_rx(&mut data.buf[chunk.start..chunk.end])); - rx.peer = chunk.addr; - - if let Some(exchange_ctr) = - transport.process_rx(construction_notification, &mut rx)? - { - let exchange_id = exchange_ctr.id().clone(); - - info!("Transport: got new exchange: {:?}", exchange_id); - - channel.send(exchange_ctr).await; - info!("Transport: exchange sent"); - - transport - .wait_construction(construction_notification, &rx, &exchange_id) - .await?; - - info!("Transport: exchange started"); - } - - data.chunk = None; - rx_pipe.data_consumed_notification.signal(()); - } - } - - rx_pipe.data_supplied_notification.wait().await - } - - #[allow(unreachable_code)] - Ok::<_, Error>(()) - }); + let mut rx = + pin!(transport.handle_rx_multiplex(rx_pipe, &construction_notification, &channel)); let result = select(&mut rx, select_slice(&mut handlers)).await; @@ -381,81 +318,4 @@ impl<'a> TransportRunner<'a> { Ok(()) } - - async fn handle_tx(transport: &Transport<'_>, tx_pipe: &Pipe<'_>) -> Result<(), Error> { - loop { - loop { - { - let mut data = tx_pipe.data.lock().await; - - if data.chunk.is_none() { - let mut tx = alloc!(Packet::new_tx(data.buf)); - - if transport.pull_tx(&mut tx).await? { - data.chunk = Some(Chunk { - start: tx.get_writebuf()?.get_start(), - end: tx.get_writebuf()?.get_tail(), - addr: tx.peer, - }); - tx_pipe.data_supplied_notification.signal(()); - } else { - break; - } - } - } - - tx_pipe.data_consumed_notification.wait().await; - } - - transport.wait_tx().await?; - } - } - - #[cfg_attr(feature = "nightly", allow(clippy::await_holding_refcell_ref))] // Fine because of the async mutex - async fn handle_exchange( - transport: &Transport<'_>, - pools: *mut PacketPools, - handler_id: usize, - exchange_ctr: ExchangeCtr<'_>, - handler: &H, - ) -> Result<(), Error> - where - H: DataModelHandler, - { - let pools = unsafe { pools.as_mut() }.unwrap(); - - let tx_buf = unsafe { pools.tx[handler_id].assume_init_mut() }; - let rx_buf = unsafe { pools.rx[handler_id].assume_init_mut() }; - let rx_status_buf = unsafe { pools.sx[handler_id].assume_init_mut() }; - - let mut rx = alloc!(Packet::new_rx(rx_buf.as_mut())); - let mut tx = alloc!(Packet::new_tx(tx_buf.as_mut())); - - let mut exchange = alloc!(exchange_ctr.get(&mut rx).await?); - - match rx.get_proto_id() { - PROTO_ID_SECURE_CHANNEL => { - let sc = SecureChannel::new(transport.matter()); - - sc.handle(&mut exchange, &mut rx, &mut tx).await?; - - transport.matter().notify_changed(); - } - PROTO_ID_INTERACTION_MODEL => { - let dm = DataModel::new(handler); - - let mut rx_status = alloc!(Packet::new_rx(rx_status_buf)); - - dm.handle(&mut exchange, &mut rx, &mut tx, &mut rx_status) - .await?; - - transport.matter().notify_changed(); - } - other => { - error!("Unknown Proto-ID: {}", other); - } - } - - Ok(()) - } } From aa2d5dfe2038f6085ad939edf4de1cc2e162d59f Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 17 Jul 2023 20:26:50 +0000 Subject: [PATCH 68/72] Compatibility with embassy-net fixed multicast support --- matter/src/mdns/builtin.rs | 10 +++---- matter/src/transport/network.rs | 5 ---- matter/src/transport/runner.rs | 4 +-- matter/src/transport/udp.rs | 46 ++++++++++++--------------------- 4 files changed, 22 insertions(+), 43 deletions(-) diff --git a/matter/src/mdns/builtin.rs b/matter/src/mdns/builtin.rs index f4a5c6e6..c3049a51 100644 --- a/matter/src/mdns/builtin.rs +++ b/matter/src/mdns/builtin.rs @@ -128,9 +128,7 @@ impl<'a> MdnsRunner<'a> { buffers: &mut MdnsUdpBuffers, ) -> Result<(), Error> where - D: crate::transport::network::NetworkStackMulticastDriver - + crate::transport::network::NetworkStackDriver - + 'static, + D: crate::transport::network::NetworkStackDriver, { let mut udp = crate::transport::udp::UdpListener::new( stack, @@ -139,11 +137,13 @@ impl<'a> MdnsRunner<'a> { ) .await?; - udp.join_multicast_v6(IPV6_BROADCAST_ADDR, self.0.interface)?; + udp.join_multicast_v6(IPV6_BROADCAST_ADDR, self.0.interface) + .await?; udp.join_multicast_v4( IP_BROADCAST_ADDR, crate::transport::network::Ipv4Addr::from(self.0.host.ip), - )?; + ) + .await?; let tx_pipe = Pipe::new(unsafe { buffers.tx_buf.assume_init_mut() }); let rx_pipe = Pipe::new(unsafe { buffers.rx_buf.assume_init_mut() }); diff --git a/matter/src/transport/network.rs b/matter/src/transport/network.rs index 21cfd726..850dde31 100644 --- a/matter/src/transport/network.rs +++ b/matter/src/transport/network.rs @@ -68,10 +68,6 @@ pub mod std_stack { impl NetworkStackDriver for () {} - pub trait NetworkStackMulticastDriver {} - - impl NetworkStackMulticastDriver for () {} - pub struct NetworkStack(D); impl NetworkStack<()> { @@ -85,5 +81,4 @@ pub mod std_stack { pub mod embassy_net_stack { pub use embassy_net::Stack as NetworkStack; pub use embassy_net_driver::Driver as NetworkStackDriver; - pub use smoltcp::phy::Device as NetworkStackMulticastDriver; } diff --git a/matter/src/transport/runner.rs b/matter/src/transport/runner.rs index d46b3e3a..554721b0 100644 --- a/matter/src/transport/runner.rs +++ b/matter/src/transport/runner.rs @@ -138,9 +138,7 @@ impl<'a> TransportRunner<'a> { handler: &H, ) -> Result<(), Error> where - D: crate::transport::network::NetworkStackDriver - + crate::transport::network::NetworkStackMulticastDriver - + 'static, + D: crate::transport::network::NetworkStackDriver, H: DataModelHandler, { let mut mdns_runner = crate::mdns::MdnsRunner::new(mdns); diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index a9e24095..3d27d2da 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -31,9 +31,7 @@ pub mod async_io { use log::{debug, info, warn}; - use crate::transport::network::std_stack::{ - NetworkStack, NetworkStackDriver, NetworkStackMulticastDriver, - }; + use crate::transport::network::std_stack::{NetworkStack, NetworkStackDriver}; use crate::transport::network::{Ipv4Addr, Ipv6Addr, SocketAddr}; pub struct UdpBuffers(()); @@ -46,11 +44,11 @@ pub mod async_io { pub struct UdpListener<'a, D>(Async, &'a NetworkStack) where - D: NetworkStackDriver; + D: NetworkStackDriver + 'static; impl<'a, D> UdpListener<'a, D> where - D: NetworkStackDriver + 'a, + D: NetworkStackDriver + 'a + 'static, { pub async fn new( stack: &'a NetworkStack, @@ -64,14 +62,11 @@ pub mod async_io { Ok(listener) } - pub fn join_multicast_v6( + pub async fn join_multicast_v6( &mut self, multiaddr: Ipv6Addr, interface: u32, - ) -> Result<(), Error> - where - D: NetworkStackMulticastDriver + 'static, - { + ) -> Result<(), Error> { self.0.get_ref().join_multicast_v6(&multiaddr, interface)?; info!("Joined IPV6 multicast {}/{}", multiaddr, interface); @@ -79,14 +74,11 @@ pub mod async_io { Ok(()) } - pub fn join_multicast_v4( + pub async fn join_multicast_v4( &mut self, multiaddr: Ipv4Addr, interface: Ipv4Addr, - ) -> Result<(), Error> - where - D: NetworkStackMulticastDriver + 'static, - { + ) -> Result<(), Error> { #[cfg(not(target_os = "espidf"))] self.0.get_ref().join_multicast_v4(&multiaddr, &interface)?; @@ -181,9 +173,7 @@ pub mod embassy_net { use log::{debug, info, warn}; - use crate::transport::network::embassy_net_stack::{ - NetworkStack, NetworkStackDriver, NetworkStackMulticastDriver, - }; + use crate::transport::network::embassy_net_stack::{NetworkStack, NetworkStackDriver}; use crate::transport::network::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; const RX_BUF_SIZE: usize = 4096; @@ -210,11 +200,11 @@ pub mod embassy_net { pub struct UdpListener<'a, D>(UdpSocket<'a>, &'a NetworkStack) where - D: NetworkStackDriver; + D: NetworkStackDriver + 'static; impl<'a, D> UdpListener<'a, D> where - D: NetworkStackDriver + 'a, + D: NetworkStackDriver + 'a + 'static, { pub async fn new( stack: &'a NetworkStack, @@ -239,16 +229,14 @@ pub mod embassy_net { Ok(UdpListener(socket, stack)) } - pub fn join_multicast_v6( + pub async fn join_multicast_v6( &mut self, multiaddr: Ipv6Addr, _interface: u32, - ) -> Result<(), Error> - where - D: NetworkStackMulticastDriver + 'static, - { + ) -> Result<(), Error> { self.1 .join_multicast_group(Self::from_ip_addr(IpAddr::V6(multiaddr))) + .await .map_err(|e| { warn!("Error on the network: {:?}", e); ErrorCode::Network @@ -259,16 +247,14 @@ pub mod embassy_net { Ok(()) } - pub fn join_multicast_v4( + pub async fn join_multicast_v4( &mut self, multiaddr: Ipv4Addr, _interface: Ipv4Addr, - ) -> Result<(), Error> - where - D: NetworkStackMulticastDriver + 'static, - { + ) -> Result<(), Error> { self.1 .join_multicast_group(Self::from_ip_addr(IpAddr::V4(multiaddr))) + .await .map_err(|e| { warn!("Error on the network: {:?}", e); ErrorCode::Network From 263279e714f84199a0e5d81a22037db162903d78 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Tue, 18 Jul 2023 10:20:40 +0000 Subject: [PATCH 69/72] Make multicast ipv6 optional --- examples/onoff_light/src/main.rs | 3 +- matter/Cargo.toml | 15 +++---- matter/src/data_model/core.rs | 2 +- matter/src/mdns/astro.rs | 3 +- matter/src/mdns/builtin.rs | 70 +++++++++++++++++++------------- 5 files changed, 53 insertions(+), 40 deletions(-) diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 1b882ac9..3f2d8906 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -101,8 +101,7 @@ fn run() -> Result<(), Error> { 0, "matter-demo", ipv4_addr.octets(), - Some(ipv6_addr.octets()), - interface, + Some((ipv6_addr.octets(), interface)), &dev_det, matter::MATTER_PORT, ); diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 46e16c99..05fd6b88 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "matter-iot" version = "0.1.0" -edition = "2018" +edition = "2021" authors = ["Kedar Sovani "] description = "Native RUST implementation of the Matter (Smart-Home) ecosystem" repository = "https://github.com/kedars/matter-rs" @@ -17,7 +17,8 @@ path = "src/lib.rs" [features] default = ["os", "crypto_rustcrypto"] os = ["std", "backtrace", "env_logger", "nix", "critical-section/std", "embassy-sync/std", "embassy-time/std"] -std = ["alloc", "rand", "qrcode", "async-io", "esp-idf-sys/std"] +esp-idf = ["std", "crypto_rustcrypto", "esp-idf-sys", "esp-idf-hal", "esp-idf-svc"] +std = ["alloc", "rand", "qrcode", "async-io", "esp-idf-sys?/std", "embassy-time/generic-queue-16"] backtrace = [] alloc = [] nightly = [] @@ -43,10 +44,11 @@ owo-colors = "3" time = { version = "0.3", default-features = false } verhoeff = { version = "1", default-features = false } embassy-futures = "0.1" -embassy-time = { version = "0.1.1", features = ["generic-queue-8"] } +embassy-time = "0.1.1" embassy-sync = "0.2" critical-section = "1.1.1" domain = { version = "0.7.2", default_features = false, features = ["heapless"] } +portable-atomic = "1" # embassy-net dependencies embassy-net = { version = "0.1", features = ["udp", "igmp", "proto-ipv6", "medium-ethernet", "medium-ip"], optional = true } @@ -84,10 +86,9 @@ env_logger = { version = "0.10.0", optional = true } nix = { version = "0.26", features = ["net"], optional = true } [target.'cfg(target_os = "espidf")'.dependencies] -esp-idf-sys = { version = "0.33", default-features = false, features = ["native", "binstart"] } -esp-idf-hal = { version = "0.41", features = ["embassy-sync", "critical-section"] } -esp-idf-svc = { version = "0.46", features = ["embassy-time-driver"] } -embedded-svc = "0.25" +esp-idf-sys = { version = "0.33", optional = true, default-features = false, features = ["native", "binstart"] } +esp-idf-hal = { version = "0.41", optional = true, features = ["embassy-sync", "critical-section"] } # TODO: Only necessary for the examples +esp-idf-svc = { version = "0.46", optional = true, features = ["embassy-time-driver"] } # TODO: Only necessary for the examples [build-dependencies] embuild = "0.31.2" diff --git a/matter/src/data_model/core.rs b/matter/src/data_model/core.rs index 69935c5c..0a4e99a8 100644 --- a/matter/src/data_model/core.rs +++ b/matter/src/data_model/core.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use core::sync::atomic::{AtomicU32, Ordering}; +use portable_atomic::{AtomicU32, Ordering}; use super::objects::*; use crate::{ diff --git a/matter/src/mdns/astro.rs b/matter/src/mdns/astro.rs index 857fb467..1ac4331e 100644 --- a/matter/src/mdns/astro.rs +++ b/matter/src/mdns/astro.rs @@ -23,8 +23,7 @@ impl<'a> MdnsService<'a> { _id: u16, _hostname: &str, _ip: [u8; 4], - _ipv6: Option<[u8; 16]>, - _interface: u32, + _ipv6: Option<([u8; 16], u32)>, dev_det: &'a BasicInfoConfig<'a>, matter_port: u16, ) -> Self { diff --git a/matter/src/mdns/builtin.rs b/matter/src/mdns/builtin.rs index c3049a51..9845b20f 100644 --- a/matter/src/mdns/builtin.rs +++ b/matter/src/mdns/builtin.rs @@ -25,7 +25,7 @@ const PORT: u16 = 5353; pub struct MdnsService<'a> { host: Host<'a>, #[allow(unused)] - interface: u32, + interface: Option, dev_det: &'a BasicInfoConfig<'a>, matter_port: u16, services: RefCell, ServiceMode), 4>>, @@ -38,8 +38,7 @@ impl<'a> MdnsService<'a> { id: u16, hostname: &'a str, ip: [u8; 4], - ipv6: Option<[u8; 16]>, - interface: u32, + ipv6: Option<([u8; 16], u32)>, dev_det: &'a BasicInfoConfig<'a>, matter_port: u16, ) -> Self { @@ -48,9 +47,17 @@ impl<'a> MdnsService<'a> { id, hostname, ip, - ipv6, + ipv6: if let Some((ipv6, _)) = ipv6 { + Some(ipv6) + } else { + None + }, + }, + interface: if let Some((_, interface)) = ipv6 { + Some(interface) + } else { + None }, - interface, dev_det, matter_port, services: RefCell::new(heapless::Vec::new()), @@ -137,8 +144,13 @@ impl<'a> MdnsRunner<'a> { ) .await?; - udp.join_multicast_v6(IPV6_BROADCAST_ADDR, self.0.interface) - .await?; + // V6 multicast does not work with smoltcp yet (see https://github.com/smoltcp-rs/smoltcp/pull/602) + #[cfg(not(feature = "embassy-net"))] + if let Some(interface) = self.0.interface { + udp.join_multicast_v6(IPV6_BROADCAST_ADDR, interface) + .await?; + } + udp.join_multicast_v4( IP_BROADCAST_ADDR, crate::transport::network::Ipv4Addr::from(self.0.host.ip), @@ -217,35 +229,37 @@ impl<'a> MdnsRunner<'a> { IpAddr::V4(IP_BROADCAST_ADDR), IpAddr::V6(IPV6_BROADCAST_ADDR), ] { - loop { - let sent = { - let mut data = tx_pipe.data.lock().await; + if self.0.interface.is_some() || addr == IpAddr::V4(IP_BROADCAST_ADDR) { + loop { + let sent = { + let mut data = tx_pipe.data.lock().await; - if data.chunk.is_none() { - let len = self.0.host.broadcast(&self.0, data.buf, 60)?; + if data.chunk.is_none() { + let len = self.0.host.broadcast(&self.0, data.buf, 60)?; - if len > 0 { - info!("Broadasting mDNS entry to {}:{}", addr, PORT); + if len > 0 { + info!("Broadasting mDNS entry to {}:{}", addr, PORT); + + data.chunk = Some(Chunk { + start: 0, + end: len, + addr: Address::Udp(SocketAddr::new(addr, PORT)), + }); - data.chunk = Some(Chunk { - start: 0, - end: len, - addr: Address::Udp(SocketAddr::new(addr, PORT)), - }); + tx_pipe.data_supplied_notification.signal(()); + } - tx_pipe.data_supplied_notification.signal(()); + true + } else { + false } + }; - true + if sent { + break; } else { - false + tx_pipe.data_consumed_notification.wait().await; } - }; - - if sent { - break; - } else { - tx_pipe.data_consumed_notification.wait().await; } } } From 71b9a578d01d6f4c4aa59df20e57cdd39ad408ee Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Wed, 19 Jul 2023 12:22:26 +0000 Subject: [PATCH 70/72] Remove embassy-net features that matter-rs is not using --- matter/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 05fd6b88..69ad4e83 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -51,7 +51,7 @@ domain = { version = "0.7.2", default_features = false, features = ["heapless"] portable-atomic = "1" # embassy-net dependencies -embassy-net = { version = "0.1", features = ["udp", "igmp", "proto-ipv6", "medium-ethernet", "medium-ip"], optional = true } +embassy-net = { version = "0.1", features = ["igmp", "proto-ipv6", "udp"], optional = true } embassy-net-driver = { version = "0.1", optional = true } smoltcp = { version = "0.10", default-features = false, optional = true } From 916f2148f89b0a68ea6d7137e2d1ab010377c375 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Fri, 21 Jul 2023 12:06:58 +0000 Subject: [PATCH 71/72] Simplify API by combining Matter, Transport and TransportRunner; simplify Mdns and Psm runners --- .github/workflows/test-linux.yml | 6 +- Cargo.toml | 3 +- README.md | 13 +- examples/onoff_light/src/main.rs | 99 +++---- matter/Cargo.toml | 28 +- matter/src/core.rs | 36 ++- matter/src/crypto/crypto_mbedtls.rs | 8 + matter/src/crypto/crypto_openssl.rs | 8 + matter/src/crypto/mod.rs | 28 +- matter/src/data_model/objects/metadata.rs | 17 ++ matter/src/error.rs | 6 +- matter/src/interaction_model/core.rs | 4 +- matter/src/mdns.rs | 8 +- matter/src/mdns/astro.rs | 44 ++- matter/src/mdns/builtin.rs | 68 ++--- matter/src/pairing/mod.rs | 2 +- matter/src/persist.rs | 54 +++- matter/src/secure_channel/crypto.rs | 14 +- matter/src/secure_channel/mod.rs | 14 +- matter/src/transport/core.rs | 268 +++++++++++++++--- matter/src/transport/exchange.rs | 56 ++-- matter/src/transport/mod.rs | 1 - matter/src/transport/runner.rs | 319 ---------------------- matter/src/transport/session.rs | 5 + matter/src/transport/udp.rs | 8 +- matter/tests/common/im_engine.rs | 22 +- matter/tests/data_model/timed_requests.rs | 4 +- matter_macro_derive/Cargo.toml | 1 + tools/tlv_tool/Cargo.toml | 3 +- 29 files changed, 519 insertions(+), 628 deletions(-) delete mode 100644 matter/src/transport/runner.rs diff --git a/.github/workflows/test-linux.yml b/.github/workflows/test-linux.yml index 82e24255..e08c84b7 100644 --- a/.github/workflows/test-linux.yml +++ b/.github/workflows/test-linux.yml @@ -15,11 +15,11 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - crypto-backend: ['crypto_openssl', 'crypto_rustcrypto', 'crypto_mbedtls'] + crypto-backend: ['rustcrypto', 'mbedtls', 'openssl'] steps: - uses: actions/checkout@v2 - name: Build - run: cd matter; cargo build --verbose --no-default-features --features ${{matrix.crypto-backend}} + run: cd matter; cargo build --no-default-features --features ${{matrix.crypto-backend}} - name: Run tests - run: cd matter; cargo test --verbose --no-default-features --features ${{matrix.crypto-backend}} -- --test-threads=1 + run: cd matter; cargo test --no-default-features --features os,${{matrix.crypto-backend}} -- --test-threads=1 diff --git a/Cargo.toml b/Cargo.toml index 0ec37cec..6c6d58c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,8 @@ [workspace] +resolver = "2" members = ["matter", "matter_macro_derive"] -exclude = ["examples/*"] +exclude = ["examples/*", "tools/tlv_tool"] # For compatibility with ESP IDF [patch.crates-io] diff --git a/README.md b/README.md index 86de8807..3de42e9a 100644 --- a/README.md +++ b/README.md @@ -41,23 +41,24 @@ The `async` metaphor however comes with a bit higher memory usage, due to not en ## Build -Building the library: +### Building the library ``` $ cargo build ``` -Building and running the example (Linux, MacOS X): +### Building and running the example (Linux, MacOS X) ``` $ cargo run --example onoff_light ``` -Building the example (Espressif's ESP-IDF): +### Building the example (Espressif's ESP-IDF) + * Install all build prerequisites described [here](https://github.com/esp-rs/esp-idf-template#prerequisites) * Build with the following command line: ``` -export MCU=esp32; export CARGO_TARGET_XTENSA_ESP32_ESPIDF_LINKER=ldproxy; export RUSTFLAGS="-C default-linker-libraries"; export WIFI_SSID=ssid;export WIFI_PASS=pass; cargo build --example onoff_light --no-default-features --features std,crypto_rustcrypto --target xtensa-esp32-espidf -Zbuild-std=std,panic_abort +export MCU=esp32; export CARGO_TARGET_XTENSA_ESP32_ESPIDF_LINKER=ldproxy; export RUSTFLAGS="-C default-linker-libraries"; export WIFI_SSID=ssid;export WIFI_PASS=pass; cargo build --example onoff_light --no-default-features --features esp-idf --target xtensa-esp32-espidf -Zbuild-std=std,panic_abort ``` * If you are building for a different Espressif MCU, change the `MCU` variable, the `xtensa-esp32-espidf` target and the name of the `CARGO_TARGET__LINKER` variable to match your MCU and its Rust target. Available Espressif MCUs and targets are: * esp32 / xtensa-esp32-espidf @@ -69,6 +70,10 @@ export MCU=esp32; export CARGO_TARGET_XTENSA_ESP32_ESPIDF_LINKER=ldproxy; export * Put in `WIFI_SSID` / `WIFI_PASS` the SSID & password for your wireless router * Flash using the `espflash` utility described in the build prerequsites' link above +### Building the example (ESP32-XX baremetal or RP2040) + +Coming soon! + ## Test With the `chip-tool` (the current tool for testing Matter) use the Ethernet commissioning mechanism: diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 3f2d8906..e6de9b7e 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -18,6 +18,7 @@ use core::borrow::Borrow; use core::pin::pin; +use embassy_futures::select::select3; use log::info; use matter::core::{CommissioningData, Matter}; use matter::data_model::cluster_basic_information::BasicInfoConfig; @@ -27,18 +28,18 @@ use matter::data_model::objects::*; use matter::data_model::root_endpoint; use matter::data_model::system_model::descriptor; use matter::error::Error; -use matter::mdns::MdnsService; -use matter::persist::FilePsm; +use matter::mdns::{MdnsRunBuffers, MdnsService}; use matter::secure_channel::spake2p::VerifierData; +use matter::transport::core::RunBuffers; use matter::transport::network::{Ipv4Addr, Ipv6Addr, NetworkStack}; -use matter::transport::runner::{AllUdpBuffers, TransportRunner}; +use matter::utils::select::EitherUnwrap; mod dev_att; #[cfg(feature = "std")] fn main() -> Result<(), Error> { let thread = std::thread::Builder::new() - .stack_size(140 * 1024) + .stack_size(150 * 1024) .spawn(run) .unwrap(); @@ -56,11 +57,11 @@ fn run() -> Result<(), Error> { initialize_logger(); info!( - "Matter memory: mDNS={}, Matter={}, TransportRunner={}, UdpBuffers={}", + "Matter memory: mDNS={}, Matter={}, MdnsBuffers={}, RunBuffers={}", core::mem::size_of::(), core::mem::size_of::(), - core::mem::size_of::(), - core::mem::size_of::(), + core::mem::size_of::(), + core::mem::size_of::(), ); let dev_det = BasicInfoConfig { @@ -73,12 +74,6 @@ fn run() -> Result<(), Error> { device_name: "OnOff Light", }; - let psm_path = std::env::temp_dir().join("matter-iot"); - info!("Persisting from/to {}", psm_path.display()); - - #[cfg(all(feature = "std", not(target_os = "espidf")))] - let psm = matter::persist::FilePsm::new(psm_path)?; - let (ipv4_addr, ipv6_addr, interface) = initialize_network()?; let dev_att = dev_att::HardCodedDevAtt::new(); @@ -106,7 +101,7 @@ fn run() -> Result<(), Error> { matter::MATTER_PORT, ); - info!("mDNS initialized: {:p}", &mdns); + info!("mDNS initialized"); let matter = Matter::new( // vid/pid should match those in the DAC @@ -118,36 +113,28 @@ fn run() -> Result<(), Error> { matter::MATTER_PORT, ); - info!("Matter initialized: {:p}", &matter); + info!("Matter initialized"); #[cfg(all(feature = "std", not(target_os = "espidf")))] - { - let mut buf = [0; 4096]; - let buf = &mut buf; - if let Some(data) = psm.load("acls", buf)? { - matter.load_acls(data)?; - } - - if let Some(data) = psm.load("fabrics", buf)? { - matter.load_fabrics(data)?; - } - } - - let mut runner = TransportRunner::new(&matter); - - info!("Transport runner initialized: {:p}", &runner); + let mut psm = matter::persist::Psm::new(&matter, std::env::temp_dir().join("matter-iot"))?; let handler = HandlerCompat(handler(&matter)); - // NOTE (no_std): If using the `embassy-net` UDP implementation, replace this dummy stack with the `embassy-net` one - // When using a custom UDP stack, remove this + // When using a custom UDP stack, remove the network stack initialization below + // and call `Matter::run_piped()` instead, by utilizing the TX & RX `Pipe` structs + // to push/pull your UDP packets from/to the Matter stack. + // Ditto for `MdnsService`. + // + // When using the `embassy-net` feature (as opposed to the Rust Standard Library network stack), + // this initialization would be more complex. let stack = NetworkStack::new(); - let mut buffers = AllUdpBuffers::new(); + let mut mdns_buffers = MdnsRunBuffers::new(); + let mut mdns_runner = pin!(mdns.run(&stack, &mut mdns_buffers)); - let mut fut = pin!(runner.run_udp_all( + let mut buffers = RunBuffers::new(); + let mut runner = matter.run( &stack, - &mdns, &mut buffers, CommissioningData { // TODO: Hard-coded for now @@ -155,16 +142,30 @@ fn run() -> Result<(), Error> { discriminator: 250, }, &handler, - )); + ); + + info!( + "Matter transport runner memory: {}", + core::mem::size_of_val(&runner) + ); + + let mut runner = pin!(runner); + + #[cfg(all(feature = "std", not(target_os = "espidf")))] + let mut psm_runner = pin!(psm.run()); + + #[cfg(not(all(feature = "std", not(target_os = "espidf"))))] + let mut psm_runner = pin!(core::future::pending()); + + let mut runner = select3(&mut runner, &mut mdns_runner, &mut psm_runner); - // NOTE: For no_std, replace with your own no_std way of polling the future #[cfg(feature = "std")] - async_io::block_on(&mut fut)?; + async_io::block_on(&mut runner).unwrap()?; // NOTE (no_std): For no_std, replace with your own more efficient no_std executor, // because the executor used below is a simple busy-loop poller #[cfg(not(feature = "std"))] - embassy_futures::block_on(&mut fut)?; + embassy_futures::block_on(&mut runner).unwrap()?; Ok(()) } @@ -268,26 +269,6 @@ fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr, u32), Error> { Ok((ip, ipv6, 0 as _)) } -#[cfg(all(feature = "std", not(target_os = "espidf")))] -#[inline(never)] -async fn save(matter: &Matter<'_>, psm: &FilePsm) -> Result<(), Error> { - let mut buf = [0; 4096]; - let buf = &mut buf; - - loop { - matter.wait_changed().await; - if matter.is_changed() { - if let Some(data) = matter.store_acls(buf)? { - psm.store("acls", data)?; - } - - if let Some(data) = matter.store_fabrics(buf)? { - psm.store("fabrics", data)?; - } - } - } -} - #[cfg(target_os = "espidf")] #[inline(never)] fn initialize_logger() { diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 69ad4e83..9b5c5cc5 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -8,23 +8,23 @@ repository = "https://github.com/kedars/matter-rs" readme = "README.md" keywords = ["matter", "smart", "smart-home", "IoT", "ESP32"] categories = ["embedded", "network-programming"] -license = "MIT" +license = "Apache-2.0" [lib] name = "matter" path = "src/lib.rs" [features] -default = ["os", "crypto_rustcrypto"] +default = ["os", "mbedtls"] os = ["std", "backtrace", "env_logger", "nix", "critical-section/std", "embassy-sync/std", "embassy-time/std"] -esp-idf = ["std", "crypto_rustcrypto", "esp-idf-sys", "esp-idf-hal", "esp-idf-svc"] +esp-idf = ["std", "rustcrypto", "esp-idf-sys"] std = ["alloc", "rand", "qrcode", "async-io", "esp-idf-sys?/std", "embassy-time/generic-queue-16"] backtrace = [] alloc = [] nightly = [] -crypto_openssl = ["alloc", "openssl", "foreign-types", "hmac", "sha2"] -crypto_mbedtls = ["alloc", "mbedtls"] -crypto_rustcrypto = ["alloc", "sha2", "hmac", "pbkdf2", "hkdf", "aes", "ccm", "p256", "elliptic-curve", "crypto-bigint", "x509-cert", "rand_core"] +openssl = ["alloc", "dep:openssl", "foreign-types", "hmac", "sha2"] +mbedtls = ["alloc", "dep:mbedtls"] +rustcrypto = ["alloc", "sha2", "hmac", "pbkdf2", "hkdf", "aes", "ccm", "p256", "elliptic-curve", "crypto-bigint", "x509-cert", "rand_core"] embassy-net = ["dep:embassy-net", "dep:embassy-net-driver", "smoltcp"] [dependencies] @@ -58,10 +58,10 @@ smoltcp = { version = "0.10", default-features = false, optional = true } # STD-only dependencies rand = { version = "0.8.5", optional = true } qrcode = { version = "0.12", default-features = false, optional = true } # Print QR code -async-io = { version = "=1.12", optional = true } # =1.2 for compatibility with ESP IDF +async-io = { version = "=1.12", optional = true } # =1.12 for compatibility with ESP IDF # crypto -openssl = { git = "https://github.com/sfackler/rust-openssl", optional = true } +openssl = { version = "0.10.55", optional = true } foreign-types = { version = "0.3.2", optional = true } # rust-crypto @@ -81,18 +81,22 @@ x509-cert = { version = "0.2.0", default-features = false, features = ["pem"], o astro-dnssd = { version = "0.3" } [target.'cfg(not(target_os = "espidf"))'.dependencies] -mbedtls = { git = "https://github.com/fortanix/rust-mbedtls", optional = true } +mbedtls = { version = "0.9", optional = true } env_logger = { version = "0.10.0", optional = true } nix = { version = "0.26", features = ["net"], optional = true } [target.'cfg(target_os = "espidf")'.dependencies] -esp-idf-sys = { version = "0.33", optional = true, default-features = false, features = ["native", "binstart"] } -esp-idf-hal = { version = "0.41", optional = true, features = ["embassy-sync", "critical-section"] } # TODO: Only necessary for the examples -esp-idf-svc = { version = "0.46", optional = true, features = ["embassy-time-driver"] } # TODO: Only necessary for the examples +esp-idf-sys = { version = "0.33", optional = true, default-features = false, features = ["native"] } [build-dependencies] embuild = "0.31.2" +[target.'cfg(target_os = "espidf")'.dev-dependencies] +esp-idf-sys = { version = "0.33", default-features = false, features = ["binstart"] } +esp-idf-hal = { version = "0.41", features = ["embassy-sync", "critical-section"] } +esp-idf-svc = { version = "0.46", features = ["embassy-time-driver"] } +embedded-svc = { version = "0.25" } + [[example]] name = "onoff_light" path = "../examples/onoff_light/src/main.rs" diff --git a/matter/src/core.rs b/matter/src/core.rs index 13c0930f..f0196526 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -28,8 +28,11 @@ use crate::{ mdns::Mdns, pairing::{print_pairing_code_and_qr, DiscoveryCapabilities}, secure_channel::{pake::PaseMgr, spake2p::VerifierData}, - transport::exchange::Notification, - utils::{epoch::Epoch, rand::Rand}, + transport::{ + exchange::{ExchangeCtx, MAX_EXCHANGES}, + session::SessionMgr, + }, + utils::{epoch::Epoch, rand::Rand, select::Notification}, }; /* The Matter Port */ @@ -45,17 +48,20 @@ pub struct CommissioningData { /// The primary Matter Object pub struct Matter<'a> { - pub fabric_mgr: RefCell, - pub acl_mgr: RefCell, - pub pase_mgr: RefCell, - pub failsafe: RefCell, - pub persist_notification: Notification, - pub mdns: &'a dyn Mdns, - pub epoch: Epoch, - pub rand: Rand, - pub dev_det: &'a BasicInfoConfig<'a>, - pub dev_att: &'a dyn DevAttDataFetcher, - pub port: u16, + fabric_mgr: RefCell, + pub acl_mgr: RefCell, // Public for tests + pase_mgr: RefCell, + failsafe: RefCell, + persist_notification: Notification, + pub(crate) send_notification: Notification, + mdns: &'a dyn Mdns, + pub(crate) epoch: Epoch, + pub(crate) rand: Rand, + dev_det: &'a BasicInfoConfig<'a>, + dev_att: &'a dyn DevAttDataFetcher, + pub(crate) port: u16, + pub(crate) exchanges: RefCell>, + pub session_mgr: RefCell, // Public for tests } impl<'a> Matter<'a> { @@ -94,12 +100,15 @@ impl<'a> Matter<'a> { pase_mgr: RefCell::new(PaseMgr::new(epoch, rand)), failsafe: RefCell::new(FailSafe::new()), persist_notification: Notification::new(), + send_notification: Notification::new(), mdns, epoch, rand, dev_det, dev_att, port, + exchanges: RefCell::new(heapless::Vec::new()), + session_mgr: RefCell::new(SessionMgr::new(epoch, rand)), } } @@ -160,6 +169,7 @@ impl<'a> Matter<'a> { Ok(false) } } + pub fn notify_changed(&self) { if self.is_changed() { self.persist_notification.signal(()); diff --git a/matter/src/crypto/crypto_mbedtls.rs b/matter/src/crypto/crypto_mbedtls.rs index 1eb7a884..7403e578 100644 --- a/matter/src/crypto/crypto_mbedtls.rs +++ b/matter/src/crypto/crypto_mbedtls.rs @@ -17,6 +17,8 @@ extern crate alloc; +use core::fmt::{self, Debug}; + use alloc::sync::Arc; use log::{error, info}; @@ -355,3 +357,9 @@ impl Sha256 { Ok(()) } } + +impl Debug for Sha256 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Sha256") + } +} diff --git a/matter/src/crypto/crypto_openssl.rs b/matter/src/crypto/crypto_openssl.rs index 24fa267f..a29df5c6 100644 --- a/matter/src/crypto/crypto_openssl.rs +++ b/matter/src/crypto/crypto_openssl.rs @@ -15,6 +15,8 @@ * limitations under the License. */ +use core::fmt::{self, Debug}; + use crate::error::{Error, ErrorCode}; use crate::utils::rand::Rand; @@ -391,3 +393,9 @@ impl Sha256 { Ok(()) } } + +impl Debug for Sha256 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Sha256") + } +} diff --git a/matter/src/crypto/mod.rs b/matter/src/crypto/mod.rs index 85c40b07..04584c8c 100644 --- a/matter/src/crypto/mod.rs +++ b/matter/src/crypto/mod.rs @@ -37,37 +37,29 @@ pub const ECDH_SHARED_SECRET_LEN_BYTES: usize = 32; pub const EC_SIGNATURE_LEN_BYTES: usize = 64; -#[cfg(all(feature = "crypto_mbedtls", target_os = "espidf"))] +#[cfg(all(feature = "mbedtls", target_os = "espidf"))] mod crypto_esp_mbedtls; -#[cfg(all(feature = "crypto_mbedtls", target_os = "espidf"))] +#[cfg(all(feature = "mbedtls", target_os = "espidf"))] pub use self::crypto_esp_mbedtls::*; -#[cfg(all(feature = "crypto_mbedtls", not(target_os = "espidf")))] +#[cfg(all(feature = "mbedtls", not(target_os = "espidf")))] mod crypto_mbedtls; -#[cfg(all(feature = "crypto_mbedtls", not(target_os = "espidf")))] +#[cfg(all(feature = "mbedtls", not(target_os = "espidf")))] pub use self::crypto_mbedtls::*; -#[cfg(feature = "crypto_openssl")] +#[cfg(feature = "openssl")] mod crypto_openssl; -#[cfg(feature = "crypto_openssl")] +#[cfg(feature = "openssl")] pub use self::crypto_openssl::*; -#[cfg(feature = "crypto_rustcrypto")] +#[cfg(feature = "rustcrypto")] mod crypto_rustcrypto; -#[cfg(feature = "crypto_rustcrypto")] +#[cfg(feature = "rustcrypto")] pub use self::crypto_rustcrypto::*; -#[cfg(not(any( - feature = "crypto_openssl", - feature = "crypto_mbedtls", - feature = "crypto_rustcrypto" -)))] +#[cfg(not(any(feature = "openssl", feature = "mbedtls", feature = "rustcrypto")))] pub mod crypto_dummy; -#[cfg(not(any( - feature = "crypto_openssl", - feature = "crypto_mbedtls", - feature = "crypto_rustcrypto" -)))] +#[cfg(not(any(feature = "openssl", feature = "mbedtls", feature = "rustcrypto")))] pub use self::crypto_dummy::*; impl<'a> FromTLV<'a> for KeyPair { diff --git a/matter/src/data_model/objects/metadata.rs b/matter/src/data_model/objects/metadata.rs index 368ff9b6..3e15612d 100644 --- a/matter/src/data_model/objects/metadata.rs +++ b/matter/src/data_model/objects/metadata.rs @@ -1,3 +1,20 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + use crate::data_model::objects::Node; #[cfg(feature = "nightly")] diff --git a/matter/src/error.rs b/matter/src/error.rs index c8da8208..91ba77e4 100644 --- a/matter/src/error.rs +++ b/matter/src/error.rs @@ -149,7 +149,7 @@ impl From> for Error { } } -#[cfg(feature = "crypto_openssl")] +#[cfg(feature = "openssl")] impl From for Error { fn from(e: openssl::error::ErrorStack) -> Self { ::log::error!("Error in TLS: {}", e); @@ -157,7 +157,7 @@ impl From for Error { } } -#[cfg(all(feature = "crypto_mbedtls", not(target_os = "espidf")))] +#[cfg(all(feature = "mbedtls", not(target_os = "espidf")))] impl From for Error { fn from(e: mbedtls::Error) -> Self { ::log::error!("Error in TLS: {}", e); @@ -173,7 +173,7 @@ impl From for Error { } } -#[cfg(feature = "crypto_rustcrypto")] +#[cfg(feature = "rustcrypto")] impl From for Error { fn from(_e: ccm::aead::Error) -> Self { Self::new(ErrorCode::Crypto) diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index 4ce35837..b2d3b8c5 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -609,7 +609,7 @@ impl<'a, 'r, 'p> Interaction<'a, 'r, 'p> { rx: &mut Packet<'_>, tx: &mut Packet<'_>, ) -> Result, Error> { - let epoch = exchange.transport().matter().epoch; + let epoch = exchange.matter.epoch; let mut opcode: OpCode = rx.get_proto_opcode()?; @@ -641,7 +641,7 @@ impl<'a, 'r, 'p> Interaction<'a, 'r, 'p> { where S: FnOnce() -> u32, { - let epoch = exchange.transport().matter().epoch; + let epoch = exchange.matter.epoch; let opcode = rx.get_proto_opcode()?; let rx_data = rx.as_slice(); diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index 47a08614..c66d12f6 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -55,19 +55,15 @@ where } } -#[cfg(all(feature = "std", target_os = "macos"))] -pub use astro::MdnsRunner; #[cfg(all(feature = "std", target_os = "macos"))] pub use astro::MdnsService; #[cfg(all(feature = "std", target_os = "macos"))] pub use astro::MdnsUdpBuffers; -#[cfg(not(all(feature = "std", target_os = "macos")))] -pub use builtin::MdnsRunner; +#[cfg(any(feature = "std", feature = "embassy-net"))] +pub use builtin::MdnsRunBuffers; #[cfg(not(all(feature = "std", target_os = "macos")))] pub use builtin::MdnsService; -#[cfg(not(all(feature = "std", target_os = "macos")))] -pub use builtin::MdnsUdpBuffers; pub struct DummyMdns; diff --git a/matter/src/mdns/astro.rs b/matter/src/mdns/astro.rs index 1ac4331e..afb933d3 100644 --- a/matter/src/mdns/astro.rs +++ b/matter/src/mdns/astro.rs @@ -11,6 +11,17 @@ use log::info; use super::ServiceMode; +/// Only for API-compatibility with builtin::MdnsRunner +pub struct MdnsUdpBuffers(()); + +/// Only for API-compatibility with builtin::MdnsRunner +impl MdnsUdpBuffers { + #[inline(always)] + pub const fn new() -> Self { + Self(()) + } +} + pub struct MdnsService<'a> { dev_det: &'a BasicInfoConfig<'a>, matter_port: u16, @@ -78,16 +89,15 @@ impl<'a> MdnsService<'a> { Ok(()) } -} -/// Only for API-compatibility with builtin::MdnsRunner -pub struct MdnsUdpBuffers(()); + /// Only for API-compatibility with builtin::MdnsRunner + pub async fn run_udp(&mut self, buffers: &mut MdnsUdpBuffers) -> Result<(), Error> { + core::future::pending::>().await + } -/// Only for API-compatibility with builtin::MdnsRunner -impl MdnsUdpBuffers { - #[inline(always)] - pub const fn new() -> Self { - Self(()) + /// Only for API-compatibility with builtin::MdnsRunner + pub async fn run(&self, _tx_pipe: &Pipe<'_>, _rx_pipe: &Pipe<'_>) -> Result<(), Error> { + core::future::pending::>().await } } @@ -100,21 +110,3 @@ impl<'a> super::Mdns for MdnsService<'a> { MdnsService::remove(self, service) } } - -/// Only for API-compatibility with builtin::MdnsRunner -pub struct MdnsRunner<'a>(&'a MdnsService<'a>); - -/// Only for API-compatibility with builtin::MdnsRunner -impl<'a> MdnsRunner<'a> { - pub const fn new(mdns: &'a MdnsService<'a>) -> Self { - Self(mdns) - } - - pub async fn run_udp(&mut self, buffers: &mut MdnsUdpBuffers) -> Result<(), Error> { - core::future::pending::>().await - } - - pub async fn run(&self, _tx_pipe: &Pipe<'_>, _rx_pipe: &Pipe<'_>) -> Result<(), Error> { - core::future::pending::>().await - } -} diff --git a/matter/src/mdns/builtin.rs b/matter/src/mdns/builtin.rs index 9845b20f..e799879a 100644 --- a/matter/src/mdns/builtin.rs +++ b/matter/src/mdns/builtin.rs @@ -22,6 +22,25 @@ const IPV6_BROADCAST_ADDR: Ipv6Addr = Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x const PORT: u16 = 5353; +#[cfg(any(feature = "std", feature = "embassy-net"))] +pub struct MdnsRunBuffers { + udp: crate::transport::udp::UdpBuffers, + tx_buf: core::mem::MaybeUninit<[u8; crate::transport::packet::MAX_TX_BUF_SIZE]>, + rx_buf: core::mem::MaybeUninit<[u8; crate::transport::packet::MAX_RX_BUF_SIZE]>, +} + +#[cfg(any(feature = "std", feature = "embassy-net"))] +impl MdnsRunBuffers { + #[inline(always)] + pub const fn new() -> Self { + Self { + udp: crate::transport::udp::UdpBuffers::new(), + tx_buf: core::mem::MaybeUninit::uninit(), + rx_buf: core::mem::MaybeUninit::uninit(), + } + } +} + pub struct MdnsService<'a> { host: Host<'a>, #[allow(unused)] @@ -100,39 +119,12 @@ impl<'a> MdnsService<'a> { Ok(()) } -} - -#[cfg(any(feature = "std", feature = "embassy-net"))] -pub struct MdnsUdpBuffers { - udp: crate::transport::udp::UdpBuffers, - tx_buf: core::mem::MaybeUninit<[u8; crate::transport::packet::MAX_TX_BUF_SIZE]>, - rx_buf: core::mem::MaybeUninit<[u8; crate::transport::packet::MAX_RX_BUF_SIZE]>, -} - -#[cfg(any(feature = "std", feature = "embassy-net"))] -impl MdnsUdpBuffers { - #[inline(always)] - pub const fn new() -> Self { - Self { - udp: crate::transport::udp::UdpBuffers::new(), - tx_buf: core::mem::MaybeUninit::uninit(), - rx_buf: core::mem::MaybeUninit::uninit(), - } - } -} - -pub struct MdnsRunner<'a>(&'a MdnsService<'a>); - -impl<'a> MdnsRunner<'a> { - pub const fn new(mdns: &'a MdnsService<'a>) -> Self { - Self(mdns) - } #[cfg(any(feature = "std", feature = "embassy-net"))] - pub async fn run_udp( - &mut self, + pub async fn run( + &self, stack: &crate::transport::network::NetworkStack, - buffers: &mut MdnsUdpBuffers, + buffers: &mut MdnsRunBuffers, ) -> Result<(), Error> where D: crate::transport::network::NetworkStackDriver, @@ -146,14 +138,14 @@ impl<'a> MdnsRunner<'a> { // V6 multicast does not work with smoltcp yet (see https://github.com/smoltcp-rs/smoltcp/pull/602) #[cfg(not(feature = "embassy-net"))] - if let Some(interface) = self.0.interface { + if let Some(interface) = self.interface { udp.join_multicast_v6(IPV6_BROADCAST_ADDR, interface) .await?; } udp.join_multicast_v4( IP_BROADCAST_ADDR, - crate::transport::network::Ipv4Addr::from(self.0.host.ip), + crate::transport::network::Ipv4Addr::from(self.host.ip), ) .await?; @@ -202,14 +194,14 @@ impl<'a> MdnsRunner<'a> { } }); - let mut run = pin!(async move { self.run(tx_pipe, rx_pipe).await }); + let mut run = pin!(async move { self.run_piped(tx_pipe, rx_pipe).await }); embassy_futures::select::select3(&mut tx, &mut rx, &mut run) .await .unwrap() } - pub async fn run(&self, tx_pipe: &Pipe<'_>, rx_pipe: &Pipe<'_>) -> Result<(), Error> { + pub async fn run_piped(&self, tx_pipe: &Pipe<'_>, rx_pipe: &Pipe<'_>) -> Result<(), Error> { let mut broadcast = pin!(self.broadcast(tx_pipe)); let mut respond = pin!(self.respond(rx_pipe, tx_pipe)); @@ -220,7 +212,7 @@ impl<'a> MdnsRunner<'a> { async fn broadcast(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> { loop { select( - self.0.notification.wait(), + self.notification.wait(), Timer::after(Duration::from_secs(30)), ) .await; @@ -229,13 +221,13 @@ impl<'a> MdnsRunner<'a> { IpAddr::V4(IP_BROADCAST_ADDR), IpAddr::V6(IPV6_BROADCAST_ADDR), ] { - if self.0.interface.is_some() || addr == IpAddr::V4(IP_BROADCAST_ADDR) { + if self.interface.is_some() || addr == IpAddr::V4(IP_BROADCAST_ADDR) { loop { let sent = { let mut data = tx_pipe.data.lock().await; if data.chunk.is_none() { - let len = self.0.host.broadcast(&self.0, data.buf, 60)?; + let len = self.host.broadcast(self, data.buf, 60)?; if len > 0 { info!("Broadasting mDNS entry to {}:{}", addr, PORT); @@ -280,7 +272,7 @@ impl<'a> MdnsRunner<'a> { let mut tx_data = tx_pipe.data.lock().await; if tx_data.chunk.is_none() { - let len = self.0.host.respond(&self.0, data, tx_data.buf, 60)?; + let len = self.host.respond(self, data, tx_data.buf, 60)?; if len > 0 { info!("Replying to mDNS query from {}", rx_chunk.addr); diff --git a/matter/src/pairing/mod.rs b/matter/src/pairing/mod.rs index 253062e0..f5cb05d1 100644 --- a/matter/src/pairing/mod.rs +++ b/matter/src/pairing/mod.rs @@ -96,7 +96,7 @@ pub fn print_pairing_code_and_qr( Ok(()) } -pub(self) fn passwd_from_comm_data(comm_data: &CommissioningData) -> u32 { +fn passwd_from_comm_data(comm_data: &CommissioningData) -> u32 { // todo: should this be part of the comm_data implementation? match comm_data.verifier.data { VerifierOption::Password(pwd) => pwd, diff --git a/matter/src/persist.rs b/matter/src/persist.rs index d9a27330..a25b13a0 100644 --- a/matter/src/persist.rs +++ b/matter/src/persist.rs @@ -15,31 +15,63 @@ * limitations under the License. */ #[cfg(feature = "std")] -pub use file_psm::*; +pub use fileio::*; #[cfg(feature = "std")] -mod file_psm { +pub mod fileio { use std::fs; use std::io::{Read, Write}; - use std::path::PathBuf; + use std::path::{Path, PathBuf}; use log::info; use crate::error::{Error, ErrorCode}; + use crate::Matter; - pub struct FilePsm { + pub struct Psm<'a> { + matter: &'a Matter<'a>, dir: PathBuf, + buf: [u8; 4096], } - impl FilePsm { - pub fn new(dir: PathBuf) -> Result { + impl<'a> Psm<'a> { + #[inline(always)] + pub fn new(matter: &'a Matter<'a>, dir: PathBuf) -> Result { fs::create_dir_all(&dir)?; - Ok(Self { dir }) + info!("Persisting from/to {}", dir.display()); + + let mut buf = [0; 4096]; + + if let Some(data) = Self::load(&dir, "acls", &mut buf)? { + matter.load_acls(data)?; + } + + if let Some(data) = Self::load(&dir, "fabrics", &mut buf)? { + matter.load_fabrics(data)?; + } + + Ok(Self { matter, dir, buf }) + } + + pub async fn run(&mut self) -> Result<(), Error> { + loop { + self.matter.wait_changed().await; + + if self.matter.is_changed() { + if let Some(data) = self.matter.store_acls(&mut self.buf)? { + Self::store(&self.dir, "acls", data)?; + } + + if let Some(data) = self.matter.store_fabrics(&mut self.buf)? { + Self::store(&self.dir, "fabrics", data)?; + } + } + } } - pub fn load<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result, Error> { - let path = self.dir.join(key); + fn load<'b>(dir: &Path, key: &str, buf: &'b mut [u8]) -> Result, Error> { + let path = dir.join(key); match fs::File::open(path) { Ok(mut file) => { @@ -69,8 +101,8 @@ mod file_psm { } } - pub fn store(&self, key: &str, data: &[u8]) -> Result<(), Error> { - let path = self.dir.join(key); + fn store(dir: &Path, key: &str, data: &[u8]) -> Result<(), Error> { + let path = dir.join(key); let mut file = fs::File::create(path)?; diff --git a/matter/src/secure_channel/crypto.rs b/matter/src/secure_channel/crypto.rs index 027db690..8fa9ec53 100644 --- a/matter/src/secure_channel/crypto.rs +++ b/matter/src/secure_channel/crypto.rs @@ -15,17 +15,13 @@ * limitations under the License. */ -#[cfg(not(any( - feature = "crypto_openssl", - feature = "crypto_mbedtls", - feature = "crypto_rustcrypto" -)))] +#[cfg(not(any(feature = "openssl", feature = "mbedtls", feature = "rustcrypto")))] pub use super::crypto_dummy::CryptoSpake2; -#[cfg(all(feature = "crypto_mbedtls", target_os = "espidf"))] +#[cfg(all(feature = "mbedtls", target_os = "espidf"))] pub use super::crypto_esp_mbedtls::CryptoSpake2; -#[cfg(all(feature = "crypto_mbedtls", not(target_os = "espidf")))] +#[cfg(all(feature = "mbedtls", not(target_os = "espidf")))] pub use super::crypto_mbedtls::CryptoSpake2; -#[cfg(feature = "crypto_openssl")] +#[cfg(feature = "openssl")] pub use super::crypto_openssl::CryptoSpake2; -#[cfg(feature = "crypto_rustcrypto")] +#[cfg(feature = "rustcrypto")] pub use super::crypto_rustcrypto::CryptoSpake2; diff --git a/matter/src/secure_channel/mod.rs b/matter/src/secure_channel/mod.rs index 58020b44..9b538b60 100644 --- a/matter/src/secure_channel/mod.rs +++ b/matter/src/secure_channel/mod.rs @@ -17,19 +17,15 @@ pub mod case; pub mod common; -#[cfg(not(any( - feature = "crypto_openssl", - feature = "crypto_mbedtls", - feature = "crypto_rustcrypto" -)))] +#[cfg(not(any(feature = "openssl", feature = "mbedtls", feature = "rustcrypto")))] mod crypto_dummy; -#[cfg(all(feature = "crypto_mbedtls", target_os = "espidf"))] +#[cfg(all(feature = "mbedtls", target_os = "espidf"))] mod crypto_esp_mbedtls; -#[cfg(all(feature = "crypto_mbedtls", not(target_os = "espidf")))] +#[cfg(all(feature = "mbedtls", not(target_os = "espidf")))] mod crypto_mbedtls; -#[cfg(feature = "crypto_openssl")] +#[cfg(feature = "openssl")] pub mod crypto_openssl; -#[cfg(feature = "crypto_rustcrypto")] +#[cfg(feature = "rustcrypto")] pub mod crypto_rustcrypto; pub mod core; diff --git a/matter/src/transport/core.rs b/matter/src/transport/core.rs index 1a51b916..0874736b 100644 --- a/matter/src/transport/core.rs +++ b/matter/src/transport/core.rs @@ -15,14 +15,18 @@ * limitations under the License. */ -use core::{borrow::Borrow, cell::RefCell}; +use core::borrow::Borrow; +use core::mem::MaybeUninit; +use core::pin::pin; -use embassy_futures::select::select; +use embassy_futures::select::{select, select_slice, Either}; use embassy_sync::{blocking_mutex::raw::NoopRawMutex, channel::Channel}; use embassy_time::{Duration, Timer}; use log::{error, info, warn}; +use crate::utils::select::Notification; +use crate::CommissioningData; use crate::{ alloc, data_model::{core::DataModel, objects::DataModelHandler}, @@ -33,18 +37,17 @@ use crate::{ core::SecureChannel, }, transport::packet::Packet, + utils::select::EitherUnwrap, Matter, }; use super::{ exchange::{ - Exchange, ExchangeCtr, ExchangeCtx, ExchangeId, ExchangeState, Notification, Role, - MAX_EXCHANGES, + Exchange, ExchangeCtr, ExchangeCtx, ExchangeId, ExchangeState, Role, MAX_EXCHANGES, }, mrp::ReliableMessage, packet::{MAX_RX_BUF_SIZE, MAX_RX_STATUS_BUF_SIZE, MAX_TX_BUF_SIZE}, pipe::{Chunk, Pipe}, - session::SessionMgr, }; #[derive(Debug)] @@ -66,33 +69,216 @@ impl From for OpCodeDescriptor { } } -pub struct Transport<'a> { - matter: &'a Matter<'a>, - pub(crate) exchanges: RefCell>, - pub(crate) send_notification: Notification, - pub session_mgr: RefCell, +type TxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; +type RxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>; +type SxBuf = MaybeUninit<[u8; MAX_RX_STATUS_BUF_SIZE]>; + +#[cfg(any(feature = "std", feature = "embassy-net"))] +pub struct RunBuffers { + udp_bufs: crate::transport::udp::UdpBuffers, + run_bufs: PacketBuffers, + tx_buf: TxBuf, + rx_buf: RxBuf, } -impl<'a> Transport<'a> { +#[cfg(any(feature = "std", feature = "embassy-net"))] +impl RunBuffers { #[inline(always)] - pub fn new(matter: &'a Matter<'a>) -> Self { - let epoch = matter.epoch; - let rand = matter.rand; + pub const fn new() -> Self { + Self { + udp_bufs: crate::transport::udp::UdpBuffers::new(), + run_bufs: PacketBuffers::new(), + tx_buf: core::mem::MaybeUninit::uninit(), + rx_buf: core::mem::MaybeUninit::uninit(), + } + } +} + +pub struct PacketBuffers { + tx: [TxBuf; MAX_EXCHANGES], + rx: [RxBuf; MAX_EXCHANGES], + sx: [SxBuf; MAX_EXCHANGES], +} + +impl PacketBuffers { + const TX_ELEM: TxBuf = MaybeUninit::uninit(); + const RX_ELEM: RxBuf = MaybeUninit::uninit(); + const SX_ELEM: SxBuf = MaybeUninit::uninit(); + + const TX_INIT: [TxBuf; MAX_EXCHANGES] = [Self::TX_ELEM; MAX_EXCHANGES]; + const RX_INIT: [RxBuf; MAX_EXCHANGES] = [Self::RX_ELEM; MAX_EXCHANGES]; + const SX_INIT: [SxBuf; MAX_EXCHANGES] = [Self::SX_ELEM; MAX_EXCHANGES]; + #[inline(always)] + pub const fn new() -> Self { Self { - matter, - exchanges: RefCell::new(heapless::Vec::new()), - send_notification: Notification::new(), - session_mgr: RefCell::new(SessionMgr::new(epoch, rand)), + tx: Self::TX_INIT, + rx: Self::RX_INIT, + sx: Self::SX_INIT, } } +} + +impl<'a> Matter<'a> { + #[cfg(any(feature = "std", feature = "embassy-net"))] + pub async fn run( + &self, + stack: &crate::transport::network::NetworkStack, + buffers: &mut RunBuffers, + dev_comm: CommissioningData, + handler: &H, + ) -> Result<(), Error> + where + D: crate::transport::network::NetworkStackDriver, + H: DataModelHandler, + { + let udp = crate::transport::udp::UdpListener::new( + stack, + crate::transport::network::SocketAddr::new( + crate::transport::network::IpAddr::V6( + crate::transport::network::Ipv6Addr::UNSPECIFIED, + ), + self.port, + ), + &mut buffers.udp_bufs, + ) + .await?; + + let tx_pipe = Pipe::new(unsafe { buffers.tx_buf.assume_init_mut() }); + let rx_pipe = Pipe::new(unsafe { buffers.rx_buf.assume_init_mut() }); + + let tx_pipe = &tx_pipe; + let rx_pipe = &rx_pipe; + let udp = &udp; + let run_bufs = &mut buffers.run_bufs; + + let mut tx = pin!(async move { + loop { + { + let mut data = tx_pipe.data.lock().await; + + if let Some(chunk) = data.chunk { + udp.send(chunk.addr.unwrap_udp(), &data.buf[chunk.start..chunk.end]) + .await?; + data.chunk = None; + tx_pipe.data_consumed_notification.signal(()); + } + } + + tx_pipe.data_supplied_notification.wait().await; + } + }); + + let mut rx = pin!(async move { + loop { + { + let mut data = rx_pipe.data.lock().await; + + if data.chunk.is_none() { + let (len, addr) = udp.recv(data.buf).await?; + + data.chunk = Some(Chunk { + start: 0, + end: len, + addr: crate::transport::network::Address::Udp(addr), + }); + rx_pipe.data_supplied_notification.signal(()); + } + } + + rx_pipe.data_consumed_notification.wait().await; + } + }); + + let mut run = pin!(async move { + self.run_piped(run_bufs, tx_pipe, rx_pipe, dev_comm, handler) + .await + }); + + embassy_futures::select::select3(&mut tx, &mut rx, &mut run) + .await + .unwrap() + } + + pub async fn run_piped( + &self, + buffers: &mut PacketBuffers, + tx_pipe: &Pipe<'_>, + rx_pipe: &Pipe<'_>, + dev_comm: CommissioningData, + handler: &H, + ) -> Result<(), Error> + where + H: DataModelHandler, + { + info!("Running Matter transport"); + + let buf = unsafe { buffers.rx[0].assume_init_mut() }; + + if self.start_comissioning(dev_comm, buf)? { + info!("Comissioning started"); + } - pub fn matter(&self) -> &'a Matter<'a> { - self.matter + let construction_notification = Notification::new(); + + let mut rx = pin!(self.handle_rx(buffers, rx_pipe, &construction_notification, handler)); + let mut tx = pin!(self.handle_tx(tx_pipe)); + + select(&mut rx, &mut tx).await.unwrap() } - pub async fn initiate(&self, _fabric_id: u64, _node_id: u64) -> Result, Error> { - unimplemented!() + #[inline(always)] + async fn handle_rx( + &self, + buffers: &mut PacketBuffers, + rx_pipe: &Pipe<'_>, + construction_notification: &Notification, + handler: &H, + ) -> Result<(), Error> + where + H: DataModelHandler, + { + info!("Creating queue for {} exchanges", 1); + + let channel = Channel::::new(); + + info!("Creating {} handlers", MAX_EXCHANGES); + let mut handlers = heapless::Vec::<_, MAX_EXCHANGES>::new(); + + info!("Handlers size: {}", core::mem::size_of_val(&handlers)); + + // Unsafely allow mutable aliasing in the packet pools by different indices + let pools: *mut PacketBuffers = buffers; + + for index in 0..MAX_EXCHANGES { + let channel = &channel; + let handler_id = index; + + let pools = unsafe { pools.as_mut() }.unwrap(); + + let tx_buf = unsafe { pools.tx[handler_id].assume_init_mut() }; + let rx_buf = unsafe { pools.rx[handler_id].assume_init_mut() }; + let sx_buf = unsafe { pools.sx[handler_id].assume_init_mut() }; + + handlers + .push(self.exchange_handler(tx_buf, rx_buf, sx_buf, handler_id, channel, handler)) + .map_err(|_| ()) + .unwrap(); + } + + let mut rx = pin!(self.handle_rx_multiplex(rx_pipe, construction_notification, &channel)); + + let result = select(&mut rx, select_slice(&mut handlers)).await; + + if let Either::First(result) = result { + if let Err(e) = &result { + error!("Exitting RX loop due to an error: {:?}", e); + } + + result?; + } + + Ok(()) } #[inline(always)] @@ -230,11 +416,11 @@ impl<'a> Transport<'a> { match rx.get_proto_id() { PROTO_ID_SECURE_CHANNEL => { - let sc = SecureChannel::new(self.matter()); + let sc = SecureChannel::new(self); sc.handle(&mut exchange, &mut rx, &mut tx).await?; - self.matter().notify_changed(); + self.notify_changed(); } PROTO_ID_INTERACTION_MODEL => { let dm = DataModel::new(handler); @@ -244,7 +430,7 @@ impl<'a> Transport<'a> { dm.handle(&mut exchange, &mut rx, &mut tx, &mut rx_status) .await?; - self.matter().notify_changed(); + self.notify_changed(); } other => { error!("Unknown Proto-ID: {}", other); @@ -254,6 +440,11 @@ impl<'a> Transport<'a> { Ok(()) } + pub fn reset_transport(&self) { + self.exchanges.borrow_mut().clear(); + self.session_mgr.borrow_mut().reset(); + } + pub fn process_rx<'r>( &'r self, construction_notification: &'r Notification, @@ -297,7 +488,7 @@ impl<'a> Transport<'a> { } } - self.matter().notify_changed(); + self.notify_changed(); } } @@ -305,13 +496,13 @@ impl<'a> Transport<'a> { let constructor = ExchangeCtr { exchange: Exchange { id: ctx.id.clone(), - transport: self, + matter: self, notification: Notification::new(), }, construction_notification, }; - self.matter().notify_changed(); + self.notify_changed(); Ok(Some(constructor)) } else if src_rx.proto.proto_id == PROTO_ID_SECURE_CHANNEL @@ -338,7 +529,7 @@ impl<'a> Transport<'a> { } } - self.matter().notify_changed(); + self.notify_changed(); Ok(None) } @@ -354,7 +545,7 @@ impl<'a> Transport<'a> { let mut exchanges = self.exchanges.borrow_mut(); - let ctx = Self::get(&mut exchanges, exchange_id).unwrap(); + let ctx = ExchangeCtx::get(&mut exchanges, exchange_id).unwrap(); let state = &mut ctx.state; @@ -397,11 +588,11 @@ impl<'a> Transport<'a> { // .. // } | ExchangeState::Complete { .. } // | ExchangeState::CompleteAcknowledge { .. } - ) || ctx.mrp.is_ack_ready(*self.matter.borrow()) + ) || ctx.mrp.is_ack_ready(*self.borrow()) }); if let Some(ctx) = ctx { - self.matter().notify_changed(); + self.notify_changed(); let state = &mut ctx.state; @@ -460,7 +651,7 @@ impl<'a> Transport<'a> { dest_tx.log("Sending packet"); self.pre_send(ctx, dest_tx)?; - self.matter().notify_changed(); + self.notify_changed(); return Ok(true); } @@ -500,7 +691,7 @@ impl<'a> Transport<'a> { let session = session_mgr.mut_by_index(sess_index).unwrap(); // Decrypt the message - session.recv(self.matter.epoch, rx)?; + session.recv(self.epoch, rx)?; // Get the exchange // TODO: Handle out of space @@ -513,7 +704,7 @@ impl<'a> Transport<'a> { )?; // Message Reliability Protocol - exch.mrp.recv(rx, self.matter.epoch)?; + exch.mrp.recv(rx, self.epoch)?; Ok((exch, new)) } @@ -576,11 +767,4 @@ impl<'a> Transport<'a> { Err(ErrorCode::NoExchange.into()) } } - - pub(crate) fn get<'r>( - exchanges: &'r mut heapless::Vec, - id: &ExchangeId, - ) -> Option<&'r mut ExchangeCtx> { - exchanges.iter_mut().find(|exchange| exchange.id == *id) - } } diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index fbe3d7aa..585e458a 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -1,13 +1,11 @@ -use embassy_sync::blocking_mutex::raw::NoopRawMutex; - use crate::{ acl::Accessor, error::{Error, ErrorCode}, + utils::select::Notification, Matter, }; use super::{ - core::Transport, mrp::ReliableMessage, network::Address, packet::Packet, @@ -16,8 +14,6 @@ use super::{ pub const MAX_EXCHANGES: usize = 8; -pub type Notification = embassy_sync::signal::Signal; - #[derive(Debug, PartialEq, Eq, Copy, Clone, Default)] pub(crate) enum Role { #[default] @@ -43,6 +39,15 @@ pub(crate) struct ExchangeCtx { pub(crate) state: ExchangeState, } +impl ExchangeCtx { + pub(crate) fn get<'r>( + exchanges: &'r mut heapless::Vec, + id: &ExchangeId, + ) -> Option<&'r mut ExchangeCtx> { + exchanges.iter_mut().find(|exchange| exchange.id == *id) + } +} + #[derive(Debug, Clone)] pub(crate) enum ExchangeState { Construction { @@ -144,7 +149,7 @@ impl SessionId { } pub struct Exchange<'a> { pub(crate) id: ExchangeId, - pub(crate) transport: &'a Transport<'a>, + pub(crate) matter: &'a Matter<'a>, pub(crate) notification: Notification, } @@ -153,21 +158,8 @@ impl<'a> Exchange<'a> { &self.id } - pub fn matter(&self) -> &Matter<'a> { - self.transport.matter() - } - - pub fn transport(&self) -> &Transport<'a> { - self.transport - } - pub fn accessor(&self) -> Result, Error> { - self.with_session(|sess| { - Ok(Accessor::for_session( - sess, - &self.transport.matter().acl_mgr, - )) - }) + self.with_session(|sess| Ok(Accessor::for_session(sess, &self.matter.acl_mgr))) } pub fn with_session_mut(&self, f: F) -> Result @@ -175,7 +167,7 @@ impl<'a> Exchange<'a> { F: FnOnce(&mut Session) -> Result, { self.with_ctx(|_self, ctx| { - let mut session_mgr = _self.transport.session_mgr.borrow_mut(); + let mut session_mgr = _self.matter.session_mgr.borrow_mut(); let sess_index = session_mgr .get( @@ -201,15 +193,11 @@ impl<'a> Exchange<'a> { where F: FnOnce(&mut SessionMgr) -> Result, { - let mut session_mgr = self.transport.session_mgr.borrow_mut(); + let mut session_mgr = self.matter.session_mgr.borrow_mut(); f(&mut session_mgr) } - pub async fn initiate(&mut self, fabric_id: u64, node_id: u64) -> Result, Error> { - self.transport.initiate(fabric_id, node_id).await - } - pub async fn acknowledge(&mut self) -> Result<(), Error> { let wait = self.with_ctx_mut(|_self, ctx| { if !matches!(ctx.state, ExchangeState::Active) { @@ -222,7 +210,7 @@ impl<'a> Exchange<'a> { ctx.state = ExchangeState::Acknowledge { notification: &_self.notification as *const _, }; - _self.transport.send_notification.signal(()); + _self.matter.send_notification.signal(()); Ok(true) } @@ -249,7 +237,7 @@ impl<'a> Exchange<'a> { rx: rx as *mut _, notification: &_self.notification as *const _, }; - _self.transport.send_notification.signal(()); + _self.matter.send_notification.signal(()); Ok(()) })?; @@ -275,7 +263,7 @@ impl<'a> Exchange<'a> { tx: tx as *const _, notification: &_self.notification as *const _, }; - _self.transport.send_notification.signal(()); + _self.matter.send_notification.signal(()); Ok(()) })?; @@ -289,9 +277,9 @@ impl<'a> Exchange<'a> { where F: FnOnce(&Self, &ExchangeCtx) -> Result, { - let mut exchanges = self.transport.exchanges.borrow_mut(); + let mut exchanges = self.matter.exchanges.borrow_mut(); - let exchange = Transport::get(&mut exchanges, &self.id).ok_or(ErrorCode::NoExchange)?; // TODO + let exchange = ExchangeCtx::get(&mut exchanges, &self.id).ok_or(ErrorCode::NoExchange)?; // TODO f(self, exchange) } @@ -300,9 +288,9 @@ impl<'a> Exchange<'a> { where F: FnOnce(&mut Self, &mut ExchangeCtx) -> Result, { - let mut exchanges = self.transport.exchanges.borrow_mut(); + let mut exchanges = self.matter.exchanges.borrow_mut(); - let exchange = Transport::get(&mut exchanges, &self.id).ok_or(ErrorCode::NoExchange)?; // TODO + let exchange = ExchangeCtx::get(&mut exchanges, &self.id).ok_or(ErrorCode::NoExchange)?; // TODO f(self, exchange) } @@ -312,7 +300,7 @@ impl<'a> Drop for Exchange<'a> { fn drop(&mut self) { let _ = self.with_ctx_mut(|_self, ctx| { ctx.state = ExchangeState::Closed; - _self.transport.send_notification.signal(()); + _self.matter.send_notification.signal(()); Ok(()) }); diff --git a/matter/src/transport/mod.rs b/matter/src/transport/mod.rs index 6c5601e7..e968adb1 100644 --- a/matter/src/transport/mod.rs +++ b/matter/src/transport/mod.rs @@ -24,6 +24,5 @@ pub mod packet; pub mod pipe; pub mod plain_hdr; pub mod proto_hdr; -pub mod runner; pub mod session; pub mod udp; diff --git a/matter/src/transport/runner.rs b/matter/src/transport/runner.rs deleted file mode 100644 index 554721b0..00000000 --- a/matter/src/transport/runner.rs +++ /dev/null @@ -1,319 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use core::{mem::MaybeUninit, pin::pin}; - -use embassy_futures::select::{select, select_slice, Either}; -use embassy_sync::{blocking_mutex::raw::NoopRawMutex, channel::Channel}; - -use log::{error, info}; - -use crate::{data_model::objects::DataModelHandler, CommissioningData, Matter}; -use crate::{error::Error, transport::packet::MAX_RX_BUF_SIZE, utils::select::EitherUnwrap}; - -use super::{ - core::Transport, - exchange::{Notification, MAX_EXCHANGES}, - packet::{MAX_RX_STATUS_BUF_SIZE, MAX_TX_BUF_SIZE}, - pipe::{Chunk, Pipe}, -}; - -type TxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; -type RxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>; -type SxBuf = MaybeUninit<[u8; MAX_RX_STATUS_BUF_SIZE]>; - -struct PacketPools { - tx: [TxBuf; MAX_EXCHANGES], - rx: [RxBuf; MAX_EXCHANGES], - sx: [SxBuf; MAX_EXCHANGES], -} - -impl PacketPools { - const TX_ELEM: TxBuf = MaybeUninit::uninit(); - const RX_ELEM: RxBuf = MaybeUninit::uninit(); - const SX_ELEM: SxBuf = MaybeUninit::uninit(); - - const TX_INIT: [TxBuf; MAX_EXCHANGES] = [Self::TX_ELEM; MAX_EXCHANGES]; - const RX_INIT: [RxBuf; MAX_EXCHANGES] = [Self::RX_ELEM; MAX_EXCHANGES]; - const SX_INIT: [SxBuf; MAX_EXCHANGES] = [Self::SX_ELEM; MAX_EXCHANGES]; - - #[inline(always)] - pub const fn new() -> Self { - Self { - tx: Self::TX_INIT, - rx: Self::RX_INIT, - sx: Self::SX_INIT, - } - } -} - -#[cfg(any(feature = "std", feature = "embassy-net"))] -pub struct AllUdpBuffers { - transport: TransportUdpBuffers, - mdns: crate::mdns::MdnsUdpBuffers, -} - -#[cfg(any(feature = "std", feature = "embassy-net"))] -impl AllUdpBuffers { - #[inline(always)] - pub const fn new() -> Self { - Self { - transport: TransportUdpBuffers::new(), - mdns: crate::mdns::MdnsUdpBuffers::new(), - } - } -} - -#[cfg(any(feature = "std", feature = "embassy-net"))] -pub struct TransportUdpBuffers { - udp: crate::transport::udp::UdpBuffers, - tx_buf: TxBuf, - rx_buf: RxBuf, -} - -#[cfg(any(feature = "std", feature = "embassy-net"))] -impl TransportUdpBuffers { - #[inline(always)] - pub const fn new() -> Self { - Self { - udp: crate::transport::udp::UdpBuffers::new(), - tx_buf: core::mem::MaybeUninit::uninit(), - rx_buf: core::mem::MaybeUninit::uninit(), - } - } -} - -/// This struct implements an executor-agnostic option to run the Matter transport stack end-to-end. -/// -/// Since it is not possible to use executor tasks spawning in an executor-agnostic way (yet), -/// the async loops are arranged as one giant future. Therefore, the cost is a slightly slower execution -/// due to the generated future being relatively big and deeply nested. -/// -/// Users are free to implement their own async execution loop, by utilizing the `Transport` -/// struct directly with their async executor of choice. -pub struct TransportRunner<'a> { - transport: Transport<'a>, - pools: PacketPools, -} - -impl<'a> TransportRunner<'a> { - #[inline(always)] - pub fn new(matter: &'a Matter<'a>) -> Self { - Self::wrap(Transport::new(matter)) - } - - #[inline(always)] - pub const fn wrap(transport: Transport<'a>) -> Self { - Self { - transport, - pools: PacketPools::new(), - } - } - - pub fn transport(&self) -> &Transport { - &self.transport - } - - #[cfg(any(feature = "std", feature = "embassy-net"))] - pub async fn run_udp_all( - &mut self, - stack: &crate::transport::network::NetworkStack, - mdns: &crate::mdns::MdnsService<'_>, - buffers: &mut AllUdpBuffers, - dev_comm: CommissioningData, - handler: &H, - ) -> Result<(), Error> - where - D: crate::transport::network::NetworkStackDriver, - H: DataModelHandler, - { - let mut mdns_runner = crate::mdns::MdnsRunner::new(mdns); - - let mut mdns = pin!(mdns_runner.run_udp(stack, &mut buffers.mdns)); - let mut transport = pin!(self.run_udp(stack, &mut buffers.transport, dev_comm, handler)); - - embassy_futures::select::select(&mut mdns, &mut transport) - .await - .unwrap() - } - - #[cfg(any(feature = "std", feature = "embassy-net"))] - pub async fn run_udp( - &mut self, - stack: &crate::transport::network::NetworkStack, - buffers: &mut TransportUdpBuffers, - dev_comm: CommissioningData, - handler: &H, - ) -> Result<(), Error> - where - D: crate::transport::network::NetworkStackDriver, - H: DataModelHandler, - { - let udp = crate::transport::udp::UdpListener::new( - stack, - crate::transport::network::SocketAddr::new( - crate::transport::network::IpAddr::V6( - crate::transport::network::Ipv6Addr::UNSPECIFIED, - ), - self.transport.matter().port, - ), - &mut buffers.udp, - ) - .await?; - - let tx_pipe = Pipe::new(unsafe { buffers.tx_buf.assume_init_mut() }); - let rx_pipe = Pipe::new(unsafe { buffers.rx_buf.assume_init_mut() }); - - let tx_pipe = &tx_pipe; - let rx_pipe = &rx_pipe; - let udp = &udp; - - let mut tx = pin!(async move { - loop { - { - let mut data = tx_pipe.data.lock().await; - - if let Some(chunk) = data.chunk { - udp.send(chunk.addr.unwrap_udp(), &data.buf[chunk.start..chunk.end]) - .await?; - data.chunk = None; - tx_pipe.data_consumed_notification.signal(()); - } - } - - tx_pipe.data_supplied_notification.wait().await; - } - }); - - let mut rx = pin!(async move { - loop { - { - let mut data = rx_pipe.data.lock().await; - - if data.chunk.is_none() { - let (len, addr) = udp.recv(data.buf).await?; - - data.chunk = Some(Chunk { - start: 0, - end: len, - addr: crate::transport::network::Address::Udp(addr), - }); - rx_pipe.data_supplied_notification.signal(()); - } - } - - rx_pipe.data_consumed_notification.wait().await; - } - }); - - let mut run = pin!(async move { self.run(tx_pipe, rx_pipe, dev_comm, handler).await }); - - embassy_futures::select::select3(&mut tx, &mut rx, &mut run) - .await - .unwrap() - } - - pub async fn run( - &mut self, - tx_pipe: &Pipe<'_>, - rx_pipe: &Pipe<'_>, - dev_comm: CommissioningData, - handler: &H, - ) -> Result<(), Error> - where - H: DataModelHandler, - { - info!("Running Matter transport"); - - let buf = unsafe { self.pools.rx[0].assume_init_mut() }; - - if self.transport.matter().start_comissioning(dev_comm, buf)? { - info!("Comissioning started"); - } - - let construction_notification = Notification::new(); - - let mut rx = pin!(Self::handle_rx( - &self.transport, - &mut self.pools, - rx_pipe, - &construction_notification, - handler - )); - let mut tx = pin!(self.transport.handle_tx(tx_pipe)); - - select(&mut rx, &mut tx).await.unwrap() - } - - #[inline(always)] - async fn handle_rx( - transport: &Transport<'_>, - pools: &mut PacketPools, - rx_pipe: &Pipe<'_>, - construction_notification: &Notification, - handler: &H, - ) -> Result<(), Error> - where - H: DataModelHandler, - { - info!("Creating queue for {} exchanges", 1); - - let channel = Channel::::new(); - - info!("Creating {} handlers", MAX_EXCHANGES); - let mut handlers = heapless::Vec::<_, MAX_EXCHANGES>::new(); - - info!("Handlers size: {}", core::mem::size_of_val(&handlers)); - - // Unsafely allow mutable aliasing in the packet pools by different indices - let pools: *mut PacketPools = pools; - - for index in 0..MAX_EXCHANGES { - let channel = &channel; - let handler_id = index; - - let pools = unsafe { pools.as_mut() }.unwrap(); - - let tx_buf = unsafe { pools.tx[handler_id].assume_init_mut() }; - let rx_buf = unsafe { pools.rx[handler_id].assume_init_mut() }; - let sx_buf = unsafe { pools.sx[handler_id].assume_init_mut() }; - - handlers - .push( - transport - .exchange_handler(tx_buf, rx_buf, sx_buf, handler_id, channel, handler), - ) - .map_err(|_| ()) - .unwrap(); - } - - let mut rx = - pin!(transport.handle_rx_multiplex(rx_pipe, &construction_notification, &channel)); - - let result = select(&mut rx, select_slice(&mut handlers)).await; - - if let Either::First(result) = result { - if let Err(e) = &result { - error!("Exitting RX loop due to an error: {:?}", e); - } - - result?; - } - - Ok(()) - } -} diff --git a/matter/src/transport/session.rs b/matter/src/transport/session.rs index c421244e..41fbc497 100644 --- a/matter/src/transport/session.rs +++ b/matter/src/transport/session.rs @@ -306,6 +306,11 @@ impl SessionMgr { } } + pub fn reset(&mut self) { + self.sessions.clear(); + self.next_sess_id = 1; + } + pub fn mut_by_index(&mut self, index: usize) -> Option<&mut Session> { self.sessions.get_mut(index).and_then(Option::as_mut) } diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index 3d27d2da..602dc120 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -16,10 +16,10 @@ */ #[cfg(all(feature = "std", not(feature = "embassy-net")))] -pub use async_io::*; +pub use self::async_io::*; #[cfg(feature = "embassy-net")] -pub use embassy_net::*; +pub use self::embassy_net::*; #[cfg(feature = "std")] pub mod async_io { @@ -88,7 +88,7 @@ pub mod async_io { #[cfg(target_os = "espidf")] { fn esp_setsockopt( - socket: &mut UdpSocket, + socket: &UdpSocket, proto: u32, option: u32, value: T, @@ -119,7 +119,7 @@ pub mod async_io { }; esp_setsockopt( - &mut self.0, + &mut self.0.get_ref(), esp_idf_sys::IPPROTO_IP, esp_idf_sys::IP_ADD_MEMBERSHIP, mreq, diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index 8efb2c90..1cd26bd9 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -48,16 +48,15 @@ use matter::{ secure_channel::{self, common::PROTO_ID_SECURE_CHANNEL, spake2p::VerifierData}, tlv::{TLVWriter, TagType, ToTLV}, transport::{ - exchange::Notification, + core::PacketBuffers, packet::{Packet, MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE}, pipe::Pipe, - runner::TransportRunner, }, transport::{ network::Address, session::{CaseDetails, CloneData, NocCatIds, SessionMode}, }, - utils::select::EitherUnwrap, + utils::select::{EitherUnwrap, Notification}, CommissioningData, Matter, MATTER_PORT, }; @@ -248,7 +247,7 @@ impl<'a> ImEngine<'a> { input: &[&ImInput], out: &mut heapless::Vec, ) -> Result<(), Error> { - let mut runner = TransportRunner::new(&self.matter); + self.matter.reset_transport(); let clone_data = CloneData::new( IM_ENGINE_REMOTE_PEER_ID, @@ -259,8 +258,8 @@ impl<'a> ImEngine<'a> { SessionMode::Case(CaseDetails::new(1, &self.cat_ids)), ); - let sess_idx = runner - .transport() + let sess_idx = self + .matter .session_mgr .borrow_mut() .clone_session(&clone_data) @@ -281,10 +280,9 @@ impl<'a> ImEngine<'a> { let rx_pipe_buf = &mut rx_pipe_buf; let handler = &handler; - let runner = &mut runner; - let mut msg_ctr = runner - .transport() + let mut msg_ctr = self + .matter .session_mgr .borrow_mut() .mut_by_index(sess_idx) @@ -294,9 +292,13 @@ impl<'a> ImEngine<'a> { let resp_notif = Notification::new(); let resp_notif = &resp_notif; + let mut buffers = PacketBuffers::new(); + let buffers = &mut buffers; + embassy_futures::block_on(async move { select3( - runner.run( + self.matter.run_piped( + buffers, tx_pipe, rx_pipe, CommissioningData { diff --git a/matter/tests/data_model/timed_requests.rs b/matter/tests/data_model/timed_requests.rs index e4eb960e..c2555062 100644 --- a/matter/tests/data_model/timed_requests.rs +++ b/matter/tests/data_model/timed_requests.rs @@ -100,7 +100,7 @@ fn test_timed_cmd_success() { ImEngine::timed_commands( input, &TimedInvResponse::TransactionSuccess(expected), - 400, + 2000, 0, true, ); @@ -130,7 +130,7 @@ fn test_timed_cmd_timedout_mismatch() { ImEngine::timed_commands( input, &TimedInvResponse::TransactionError(IMStatusCode::TimedRequestMisMatch), - 400, + 2000, 0, false, ); diff --git a/matter_macro_derive/Cargo.toml b/matter_macro_derive/Cargo.toml index 163ff502..5bf29bb2 100644 --- a/matter_macro_derive/Cargo.toml +++ b/matter_macro_derive/Cargo.toml @@ -2,6 +2,7 @@ name = "matter_macro_derive" version = "0.1.0" edition = "2021" +license = "Apache-2.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib] diff --git a/tools/tlv_tool/Cargo.toml b/tools/tlv_tool/Cargo.toml index f8c1e232..f4c10351 100644 --- a/tools/tlv_tool/Cargo.toml +++ b/tools/tlv_tool/Cargo.toml @@ -1,7 +1,8 @@ [package] name = "tlv_tool" version = "0.1.0" -edition = "2018" +edition = "2021" +license = "Apache-2.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html From 91e13292da7849c234051c884835a93298aad040 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sat, 22 Jul 2023 07:00:53 +0000 Subject: [PATCH 72/72] Remove the note referring to the no_std and sequential branches --- README.md | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/README.md b/README.md index 3de42e9a..327d4f75 100644 --- a/README.md +++ b/README.md @@ -5,40 +5,6 @@ [![Test Linux (OpenSSL)](https://github.com/project-chip/matter-rs/actions/workflows/test-linux-openssl.yml/badge.svg)](https://github.com/project-chip/matter-rs/actions/workflows/test-linux-openssl.yml) [![Test Linux (mbedTLS)](https://github.com/project-chip/matter-rs/actions/workflows/test-linux-mbedtls.yml/badge.svg)](https://github.com/project-chip/matter-rs/actions/workflows/test-linux-mbedtls.yml) -## Important Note - -All development work is now ongoing in two other branches ([no_std](https://github.com/project-chip/matter-rs/tree/no_std) and [sequential](https://github.com/project-chip/matter-rs/tree/sequential) - explained below). The plan is one of these two branches to become the new `main`. - -We highly encourage users to try out both of these branches (there is a working `onoff_light` example in both) and provide feedback. - -### [no_std](https://github.com/project-chip/matter-rs/tree/no_std) - -The purpose of this branch - as the name suggests - is to introduce `no_std` compatibility to the `matter-rs` library, so that it is possible to target constrained environments like MCUs which more often than not have no support for the Rust Standard library (threads, network sockets, filesystem and so on). - -We have been successful in this endeavour. The library now only requires Rust `core` and runs on e.g. ESP32 baremental Rust targets. -When `matter-rs` is used on targets that do not support the Rust Standard Library, user is expected to provide the following: - -- A `rand` function that can fill a `&[u8]` slice with random data -- An `epoch` function (a "current time" utility); note that since this utility is only used for measuring timeouts, it is OK to provide a function that e.g. measures elapsed millis since system boot, rather than something that tries to adhere to the UNIX epoch (1/1/1970) -- An MCU-specific UDP stack that the user would need to connect to the `matter-rs` library - -Besides just having `no_std` compatibility, the `no_std` branch does not need an allocator. I.e. all structures internal to the `matter-rs` librarty are statically allocated. - -Last but not least, the `no_std` branch by itself does **not** do any IO. In other words, it is "compute only" (as in, "give me a network packet and I'll produce one or more that you have to send; how you receive/send those is up to you"). Ditto for persisting fabrics and ACLs - it is up to the user to listen the matter stack for changes to those and persist. - -### [sequential](https://github.com/project-chip/matter-rs/tree/sequential) - -The `sequential` branch builds on top of the work implemented in the `no_std` branch by utilizing code implemented as `async` functions and methods. Committing to `async` has multiple benefits: - -- (Internal for the library) We were able to turn several explicit state machines into implicit ones (after all, `async` is primarily about generating state machines automatically based on "sequential" user codee that uses the async/await language constructs - hence the name of the branch) -- (External, for the user) The ergonomics of the Exchange API in this branch (in other words, the "transport aspect of the Matter CSA spec) is much better, approaching that of dealing with regular TCP/IP sockets in the Rust Standard Library. This is only possible by utilizing async functions and methods, because - let's not forget - `matter-rs` needs to run on MCUs where native threading and task scheduling capabilities might not even exist, hence "sequentially-looking" request/response interaction can only be expressed asynchronously, or with explicit state machines. -- Certain pending concepts are much easier to implement via async functions and methods: -- Re-sending packets which were not acknowledged by the receiver yet (the MRP protocol as per the Matter spec) -- The "initiator" side of an exchange (think client clusters) -- This branch provides facilities to implement asynchronous read, write and invoke handling for server clusters, which is beneficial in certain scenarios (i.e. brdige devices) - -The `async` metaphor however comes with a bit higher memory usage, due to not enough optimizations being implemented yet in the rust language when the async code is transpiled to state machines. - ## Build ### Building the library