diff --git a/benches/benchmark.rs b/benches/benchmark.rs index 1ebeb64..6f82a15 100644 --- a/benches/benchmark.rs +++ b/benches/benchmark.rs @@ -6,7 +6,13 @@ fn criterion_benchmark(c: &mut Bencher) { c.iter(|| { let data = include_bytes!("../tests/data/certificate.chain.pem"); let mut reader = BufReader::new(&data[..]); - assert_eq!(rustls_pemfile::certs(&mut reader).unwrap().len(), 3); + assert_eq!( + rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .unwrap() + .len(), + 3 + ); }); } diff --git a/src/lib.rs b/src/lib.rs index 44a98c3..6fddb7e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,22 +54,21 @@ use pki_types::{ /// --- Legacy APIs: use std::io; +use std::iter; /// Extract all the certificates from `rd`, and return a vec of byte vecs /// containing the der-format contents. /// /// This function does not fail if there are no certificates in the file -- /// it returns an empty vector. -pub fn certs(rd: &mut dyn io::BufRead) -> Result>, io::Error> { - let mut certs = Vec::new(); - - loop { - match read_one(rd)? { - None => return Ok(certs), - Some(Item::X509Certificate(cert)) => certs.push(cert), - _ => {} - }; - } +pub fn certs( + rd: &mut dyn io::BufRead, +) -> impl Iterator, io::Error>> + '_ { + iter::from_fn(move || read_one(rd).transpose()).filter_map(|item| match item { + Ok(Item::X509Certificate(cert)) => Some(Ok(cert)), + Err(err) => Some(Err(err)), + _ => None, + }) } /// Extract all the certificate revocation lists (CRLs) from `rd`, and return a vec of byte vecs @@ -79,16 +78,12 @@ pub fn certs(rd: &mut dyn io::BufRead) -> Result>, i /// it returns an empty vector. pub fn crls( rd: &mut dyn io::BufRead, -) -> Result>, io::Error> { - let mut crls = Vec::new(); - - loop { - match read_one(rd)? { - None => return Ok(crls), - Some(Item::Crl(crl)) => crls.push(crl), - _ => {} - }; - } +) -> impl Iterator, io::Error>> + '_ { + iter::from_fn(move || read_one(rd).transpose()).filter_map(|item| match item { + Ok(Item::Crl(crl)) => Some(Ok(crl)), + Err(err) => Some(Err(err)), + _ => None, + }) } /// Extract all RSA private keys from `rd`, and return a vec of byte vecs @@ -98,16 +93,12 @@ pub fn crls( /// empty vector. pub fn rsa_private_keys( rd: &mut dyn io::BufRead, -) -> Result>, io::Error> { - let mut keys = Vec::new(); - - loop { - match read_one(rd)? { - None => return Ok(keys), - Some(Item::RSAKey(key)) => keys.push(key), - _ => {} - }; - } +) -> impl Iterator, io::Error>> + '_ { + iter::from_fn(move || read_one(rd).transpose()).filter_map(|item| match item { + Ok(Item::RSAKey(key)) => Some(Ok(key)), + Err(err) => Some(Err(err)), + _ => None, + }) } /// Extract all PKCS8-encoded private keys from `rd`, and return a vec of @@ -117,16 +108,12 @@ pub fn rsa_private_keys( /// empty vector. pub fn pkcs8_private_keys( rd: &mut dyn io::BufRead, -) -> Result>, io::Error> { - let mut keys = Vec::new(); - - loop { - match read_one(rd)? { - None => return Ok(keys), - Some(Item::PKCS8Key(key)) => keys.push(key), - _ => {} - }; - } +) -> impl Iterator, io::Error>> + '_ { + iter::from_fn(move || read_one(rd).transpose()).filter_map(|item| match item { + Ok(Item::PKCS8Key(key)) => Some(Ok(key)), + Err(err) => Some(Err(err)), + _ => None, + }) } /// Extract all SEC1-encoded EC private keys from `rd`, and return a vec of @@ -136,14 +123,10 @@ pub fn pkcs8_private_keys( /// empty vector. pub fn ec_private_keys( rd: &mut dyn io::BufRead, -) -> Result>, io::Error> { - let mut keys = Vec::new(); - - loop { - match read_one(rd)? { - None => return Ok(keys), - Some(Item::ECKey(key)) => keys.push(key), - _ => {} - }; - } +) -> impl Iterator, io::Error>> + '_ { + iter::from_fn(move || read_one(rd).transpose()).filter_map(|item| match item { + Ok(Item::ECKey(key)) => Some(Ok(key)), + Err(err) => Some(Err(err)), + _ => None, + }) } diff --git a/src/pemfile.rs b/src/pemfile.rs index 7159c70..398752c 100644 --- a/src/pemfile.rs +++ b/src/pemfile.rs @@ -1,4 +1,5 @@ use std::io::{self, ErrorKind}; +use std::iter; use pki_types::{ CertificateDer, CertificateRevocationListDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer, @@ -122,15 +123,8 @@ pub fn read_one(rd: &mut dyn io::BufRead) -> Result, io::Error> { } /// Extract and return all PEM sections by reading `rd`. -pub fn read_all(rd: &mut dyn io::BufRead) -> Result, io::Error> { - let mut v = Vec::::new(); - - loop { - match read_one(rd)? { - None => return Ok(v), - Some(item) => v.push(item), - } - } +pub fn read_all(rd: &mut dyn io::BufRead) -> impl Iterator> + '_ { + iter::from_fn(move || read_one(rd).transpose()) } mod base64 { diff --git a/src/tests.rs b/src/tests.rs index c17246c..d8dac65 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -2,7 +2,7 @@ mod unit { fn check(data: &[u8]) -> Result, std::io::Error> { let mut reader = std::io::BufReader::new(data); - crate::read_all(&mut reader) + crate::read_all(&mut reader).collect() } #[test] diff --git a/tests/integration.rs b/tests/integration.rs index 71c7345..e27335a 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -7,7 +7,10 @@ fn test_rsa_private_keys() { let mut reader = BufReader::new(&data[..]); assert_eq!( - rustls_pemfile::rsa_private_keys(&mut reader).unwrap().len(), + rustls_pemfile::rsa_private_keys(&mut reader) + .collect::, _>>() + .unwrap() + .len(), 2 ); } @@ -17,21 +20,39 @@ fn test_certs() { let data = include_bytes!("data/certificate.chain.pem"); let mut reader = BufReader::new(&data[..]); - assert_eq!(rustls_pemfile::certs(&mut reader).unwrap().len(), 3); + assert_eq!( + rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .unwrap() + .len(), + 3 + ); } #[test] fn test_certs_with_binary() { let data = include_bytes!("data/gunk.pem"); let mut reader = BufReader::new(&data[..]); - assert_eq!(rustls_pemfile::certs(&mut reader).unwrap().len(), 2); + assert_eq!( + rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .unwrap() + .len(), + 2 + ); } #[test] fn test_crls() { let data = include_bytes!("data/crl.pem"); let mut reader = BufReader::new(&data[..]); - assert_eq!(rustls_pemfile::crls(&mut reader).unwrap().len(), 1); + assert_eq!( + rustls_pemfile::crls(&mut reader) + .collect::, _>>() + .unwrap() + .len(), + 1 + ); } #[test] @@ -41,6 +62,7 @@ fn test_pkcs8() { assert_eq!( rustls_pemfile::pkcs8_private_keys(&mut reader) + .collect::, _>>() .unwrap() .len(), 2 @@ -52,7 +74,9 @@ fn test_sec1() { let data = include_bytes!("data/nistp256key.pem"); let mut reader = BufReader::new(&data[..]); - let items = rustls_pemfile::read_all(&mut reader).unwrap(); + let items = rustls_pemfile::read_all(&mut reader) + .collect::, _>>() + .unwrap(); assert_eq!(items.len(), 1); assert!(matches!(items[0], rustls_pemfile::Item::ECKey(_))); } @@ -78,7 +102,9 @@ fn test_sec1_vs_pkcs8() { let data = include_bytes!("data/nistp256key.pem"); let mut reader = BufReader::new(&data[..]); - let items = rustls_pemfile::read_all(&mut reader).unwrap(); + let items = rustls_pemfile::read_all(&mut reader) + .collect::, _>>() + .unwrap(); assert!(matches!(items[0], rustls_pemfile::Item::ECKey(_))); println!("sec1 {:?}", items); } @@ -86,7 +112,9 @@ fn test_sec1_vs_pkcs8() { let data = include_bytes!("data/nistp256key.pkcs8.pem"); let mut reader = BufReader::new(&data[..]); - let items = rustls_pemfile::read_all(&mut reader).unwrap(); + let items = rustls_pemfile::read_all(&mut reader) + .collect::, _>>() + .unwrap(); assert!(matches!(items[0], rustls_pemfile::Item::PKCS8Key(_))); println!("p8 {:?}", items); } @@ -97,7 +125,9 @@ fn parse_in_order() { let data = include_bytes!("data/zen.pem"); let mut reader = BufReader::new(&data[..]); - let items = rustls_pemfile::read_all(&mut reader).unwrap(); + let items = rustls_pemfile::read_all(&mut reader) + .collect::, _>>() + .unwrap(); assert_eq!(items.len(), 9); assert!(matches!(items[0], rustls_pemfile::Item::X509Certificate(_))); assert!(matches!(items[1], rustls_pemfile::Item::X509Certificate(_)));