Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

state.rs: Harden C enum casting #105

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions pyth-sdk-solana/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
//! Structures and functions for interacting with Solana on-chain account data.
//!
//! NOTE(2023-05-12): enums defined here use u32 corresponding with
//! uint32_t that's currently used in pyth-client's oracle.h struct
//! definitions. Enum correctness is validated with bytemuck's checked
//! casting functions and derive(CheckedBitPattern) on the relevant
//! enums.

use borsh::{
BorshDeserialize,
BorshSerialize,
};
use bytemuck::{
use bytemuck::checked::{
cast_slice,
from_bytes,
try_cast_slice,
CheckedCastError,
};
use bytemuck::{
CheckedBitPattern,
Pod,
PodCastError,
Zeroable,
Expand Down Expand Up @@ -46,10 +56,11 @@ pub const PROD_ATTR_SIZE: usize = PROD_ACCT_SIZE - PROD_HDR_SIZE;
BorshDeserialize,
serde::Serialize,
serde::Deserialize,
CheckedBitPattern,
)]
#[repr(C)]
#[repr(u32)]
pub enum AccountType {
Unknown,
Unknown = 0,
Mapping,
Product,
Price,
Expand All @@ -73,10 +84,11 @@ impl Default for AccountType {
BorshDeserialize,
serde::Serialize,
serde::Deserialize,
CheckedBitPattern,
)]
#[repr(C)]
#[repr(u32)]
pub enum CorpAction {
NoCorpAct,
NoCorpAct = 0,
}

impl Default for CorpAction {
Expand All @@ -97,10 +109,11 @@ impl Default for CorpAction {
BorshDeserialize,
serde::Serialize,
serde::Deserialize,
CheckedBitPattern,
)]
#[repr(C)]
#[repr(u32)]
pub enum PriceType {
Unknown,
Unknown = 0,
Price,
}

Expand All @@ -122,11 +135,12 @@ impl Default for PriceType {
BorshDeserialize,
serde::Serialize,
serde::Deserialize,
CheckedBitPattern,
)]
#[repr(C)]
#[repr(u32)]
pub enum PriceStatus {
/// The price feed is not currently updating for an unknown reason.
Unknown,
Unknown = 0,
/// The price feed is updating as expected.
Trading,
/// The price feed is not currently updating because trading in the product has been halted.
Expand Down Expand Up @@ -410,14 +424,14 @@ impl PriceAccount {
}
}

fn load<T: Pod>(data: &[u8]) -> Result<&T, PodCastError> {
fn load<T: Pod>(data: &[u8]) -> Result<&T, CheckedCastError> {
let size = size_of::<T>();
if data.len() >= size {
Ok(from_bytes(cast_slice::<u8, u8>(try_cast_slice(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated : what is the point of all these cast_slice and try_cast_slice seems redundant.

&data[0..size],
)?)))
} else {
Err(PodCastError::SizeMismatch)
Err(CheckedCastError::PodCastError(PodCastError::SizeMismatch))
}
}

Expand Down Expand Up @@ -502,6 +516,7 @@ fn get_attr_str(buf: &[u8]) -> (&str, &[u8]) {

#[cfg(test)]
mod test {
use bytemuck::checked::try_from_bytes;
use pyth_sdk::{
Identifier,
Price,
Expand Down Expand Up @@ -737,4 +752,26 @@ mod test {

assert_eq!(price_account.get_price_no_older_than(&clock, 1), None);
}

/// Ensure that bytemuck::checked::* casting functions accept
/// valid bytes
#[test]
fn test_happy_recognized_price_status() {
let happy_status_bytes = 2u32.to_le_bytes();

let happy_status_result = try_from_bytes::<PriceStatus>(happy_status_bytes.as_slice());

assert_eq!(happy_status_result, Ok(&PriceStatus::Halted));
}

/// Ensure that bytemuck::checked::* casting functions reject
/// invalid bytes
#[test]
fn test_sad_unrecognized_price_status() {
let sad_status_bytes = 42_000u32.to_le_bytes();

let sad_status_result = try_from_bytes::<PriceStatus>(sad_status_bytes.as_slice());

assert!(sad_status_result.is_err());
}
}