Fixing clippy errors
coreylowman committed Sep 14, 2023
1 parent fbe96e5 commit 239c21d
Showing 8 changed files with 34 additions and 65 deletions.
14 changes: 7 additions & 7 deletions dfdx-nn-core/src/lib.rs
@@ -153,7 +153,7 @@ pub trait SaveSafeTensors {
         let data = tensors.iter().map(|(k, dtype, shape, data)| {
             (
                 k.clone(),
-                safetensors::tensor::TensorView::new(dtype.clone(), shape.clone(), data).unwrap(),
+                safetensors::tensor::TensorView::new(*dtype, shape.clone(), data).unwrap(),
             )
         });
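The one-line change above fixes clippy's `clone_on_copy` lint: the safetensors `Dtype` is a small `Copy` enum, so dereferencing the borrow is cheaper and more idiomatic than cloning it. A minimal sketch of the trigger and the fix, using an illustrative stand-in enum rather than the real safetensors type:

```rust
// Stand-in for a small Copy enum like safetensors' Dtype.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Dtype {
    F32,
    F64,
}

fn main() {
    let dtype = &Dtype::F32;
    let cloned: Dtype = dtype.clone(); // clippy::clone_on_copy fires here
    let copied: Dtype = *dtype; // the suggested fix: just dereference
    assert_eq!(cloned, copied);
}
```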

@@ -178,18 +178,18 @@ pub trait LoadSafeTensors {
         self.read_safetensors("", &tensors)
     }
 
-    fn read_safetensors<'a>(
+    fn read_safetensors(
         &mut self,
         location: &str,
-        tensors: &safetensors::SafeTensors<'a>,
+        tensors: &safetensors::SafeTensors,
     ) -> Result<(), safetensors::SafeTensorError>;
 }
 
 impl<S: Shape, E: Dtype, D: Device<E>, T> LoadSafeTensors for Tensor<S, E, D, T> {
-    fn read_safetensors<'a>(
+    fn read_safetensors(
         &mut self,
         location: &str,
-        tensors: &safetensors::SafeTensors<'a>,
+        tensors: &safetensors::SafeTensors,
     ) -> Result<(), safetensors::SafeTensorError> {
         self.load_safetensor(tensors, location)
     }
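The signature changes here (and the identical ones in tuples.rs and vecs.rs below) all drop a lifetime parameter that was named but used in only one position, so the compiler can elide it; this is presumably clippy's `needless_lifetimes` lint. A self-contained before/after sketch with a made-up type:

```rust
struct Archive<'data> {
    bytes: &'data [u8],
}

// clippy::needless_lifetimes: 'a is used exactly once, so naming it
// adds nothing; the compiler can infer an anonymous lifetime instead.
#[allow(clippy::needless_lifetimes)]
fn len_explicit<'a>(archive: &Archive<'a>) -> usize {
    archive.bytes.len()
}

// Elided form, equivalent to the signature clippy suggests.
fn len_elided(archive: &Archive<'_>) -> usize {
    archive.bytes.len()
}

fn main() {
    let archive = Archive { bytes: &[1, 2, 3] };
    assert_eq!(len_explicit(&archive), len_elided(&archive));
}
```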
@@ -230,10 +230,10 @@ macro_rules! unit_safetensors {
         }
 
         impl LoadSafeTensors for $Ty {
-            fn read_safetensors<'a>(
+            fn read_safetensors(
                 &mut self,
                 location: &str,
-                tensors: &safetensors::SafeTensors<'a>,
+                tensors: &safetensors::SafeTensors,
             ) -> Result<(), safetensors::SafeTensorError> {
                 #[allow(unused_imports)]
                 use dfdx::dtypes::FromLeBytes;
4 changes: 2 additions & 2 deletions dfdx-nn-core/src/tuples.rs
@@ -22,10 +22,10 @@ macro_rules! tuple_impls {
         }
 
         impl<$($name: crate::LoadSafeTensors, )+> crate::LoadSafeTensors for ($($name,)+) {
-            fn read_safetensors<'a>(
+            fn read_safetensors(
                 &mut self,
                 location: &str,
-                tensors: &safetensors::SafeTensors<'a>,
+                tensors: &safetensors::SafeTensors,
             ) -> Result<(), safetensors::SafeTensorError> {
                 $(self.$idx.read_safetensors(&format!("{location}{}.", $idx), tensors)?;)+
                 Ok(())
4 changes: 2 additions & 2 deletions dfdx-nn-core/src/vecs.rs
@@ -54,10 +54,10 @@ impl<T: crate::SaveSafeTensors> crate::SaveSafeTensors for Vec<T> {
 }
 
 impl<T: crate::LoadSafeTensors> crate::LoadSafeTensors for Vec<T> {
-    fn read_safetensors<'a>(
+    fn read_safetensors(
         &mut self,
         location: &str,
-        tensors: &safetensors::SafeTensors<'a>,
+        tensors: &safetensors::SafeTensors,
     ) -> Result<(), safetensors::SafeTensorError> {
         for (i, t) in self.iter_mut().enumerate() {
             t.read_safetensors(&format!("{location}{i}."), tensors)?;
72 changes: 18 additions & 54 deletions dfdx-nn-derives/src/lib.rs
@@ -544,29 +544,17 @@ pub fn reset_params(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
     let name = input.ident;
 
     let mut custom_generics = input.generics.clone();
-    if custom_generics
-        .params
-        .iter()
-        .position(|param| match param {
-            syn::GenericParam::Type(type_param) if type_param.ident == "Elem" => true,
-            _ => false,
-        })
-        .is_none()
-    {
+    if !custom_generics.params.iter().any(
+        |param| matches!(param, syn::GenericParam::Type(type_param) if type_param.ident == "Elem"),
+    ) {
         custom_generics
             .params
             .push(parse_quote!(Elem: dfdx::prelude::Dtype));
     }
 
-    if custom_generics
-        .params
-        .iter()
-        .position(|param| match param {
-            syn::GenericParam::Type(type_param) if type_param.ident == "Dev" => true,
-            _ => false,
-        })
-        .is_none()
-    {
+    if !custom_generics.params.iter().any(
+        |param| matches!(param, syn::GenericParam::Type(type_param) if type_param.ident == "Dev"),
+    ) {
         custom_generics
             .params
             .push(parse_quote!(Dev: dfdx::prelude::Device<Elem>));
@@ -631,29 +619,17 @@ pub fn update_params(input: proc_macro::TokenStream) -> proc_macro::TokenStream
     let struct_name = input.ident;
 
     let mut custom_generics = input.generics.clone();
-    if custom_generics
-        .params
-        .iter()
-        .position(|param| match param {
-            syn::GenericParam::Type(type_param) if type_param.ident == "Elem" => true,
-            _ => false,
-        })
-        .is_none()
-    {
+    if !custom_generics.params.iter().any(
+        |param| matches!(param, syn::GenericParam::Type(type_param) if type_param.ident == "Elem"),
+    ) {
         custom_generics
             .params
             .push(parse_quote!(Elem: dfdx::prelude::Dtype));
     }
 
-    if custom_generics
-        .params
-        .iter()
-        .position(|param| match param {
-            syn::GenericParam::Type(type_param) if type_param.ident == "Dev" => true,
-            _ => false,
-        })
-        .is_none()
-    {
+    if !custom_generics.params.iter().any(
+        |param| matches!(param, syn::GenericParam::Type(type_param) if type_param.ident == "Dev"),
+    ) {
         custom_generics
             .params
             .push(parse_quote!(Dev: dfdx::prelude::Device<Elem>));
@@ -727,29 +703,17 @@ pub fn zero_grads(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
     let name = input.ident;
 
     let mut custom_generics = input.generics.clone();
-    if custom_generics
-        .params
-        .iter()
-        .position(|param| match param {
-            syn::GenericParam::Type(type_param) if type_param.ident == "Elem" => true,
-            _ => false,
-        })
-        .is_none()
-    {
+    if !custom_generics.params.iter().any(
+        |param| matches!(param, syn::GenericParam::Type(type_param) if type_param.ident == "Elem"),
+    ) {
         custom_generics
             .params
             .push(parse_quote!(Elem: dfdx::prelude::Dtype));
     }
 
-    if custom_generics
-        .params
-        .iter()
-        .position(|param| match param {
-            syn::GenericParam::Type(type_param) if type_param.ident == "Dev" => true,
-            _ => false,
-        })
-        .is_none()
-    {
+    if !custom_generics.params.iter().any(
+        |param| matches!(param, syn::GenericParam::Type(type_param) if type_param.ident == "Dev"),
+    ) {
         custom_generics
             .params
             .push(parse_quote!(Dev: dfdx::prelude::Device<Elem>));
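All three derive macros (`reset_params`, `update_params`, `zero_grads`) receive the same two-part rewrite: `position(..).is_none()` becomes `!iter().any(..)`, which is clippy's `search_is_some` suggestion, and the two-arm boolean `match` collapses into `matches!`, clippy's `match_like_matches_macro` suggestion. A standalone sketch of the same transformation on a toy enum (all names here are illustrative):

```rust
enum Param {
    Lifetime,
    Type(String),
}

fn main() {
    let params = vec![Param::Lifetime, Param::Type("Dev".into())];

    // Before: triggers clippy::search_is_some (position + is_none)
    // and clippy::match_like_matches_macro (boolean two-arm match).
    #[allow(clippy::search_is_some, clippy::match_like_matches_macro)]
    let missing_before = params
        .iter()
        .position(|param| match param {
            Param::Type(name) if name == "Elem" => true,
            _ => false,
        })
        .is_none();

    // After: `any` short-circuits just like `position`, and `matches!`
    // states the boolean pattern test directly.
    let missing_after = !params
        .iter()
        .any(|param| matches!(param, Param::Type(name) if name == "Elem"));

    assert_eq!(missing_before, missing_after);
}
```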
2 changes: 2 additions & 0 deletions dfdx-nn/src/layers/add_into.rs
@@ -59,13 +59,15 @@ macro_rules! add_into_impls {
             type Output = Out;
             type Error = A::Error;
 
+            #[allow(clippy::needless_question_mark)]
             fn try_forward(&self, x: (Ai, $($Inp, )+)) -> Result<Self::Output, Self::Error> {
                 let (a, $($ModVar, )+) = &self.0;
                 let (a_i, $($InpVar, )+) = x;
                 let a_i = a.try_forward(a_i)?;
                 $(let $InpVar = $ModVar.try_forward($InpVar)?;)+
                 Ok(sum!(a_i, $($InpVar),*))
             }
+            #[allow(clippy::needless_question_mark)]
             fn try_forward_mut(&mut self, x: (Ai, $($Inp, )+)) -> Result<Self::Output, Self::Error> {
                 let (a, $($ModVar, )+) = &mut self.0;
                 let (a_i, $($InpVar, )+) = x;
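The `#[allow(clippy::needless_question_mark)]` markers suppress a lint that fires on `Ok(expr?)` where `expr` already evaluates to a `Result` of the right type; presumably some expansion of the `sum!` macro in these method bodies produces exactly that shape. A toy reproduction of the lint, with made-up function names:

```rust
fn fetch() -> Result<i32, String> {
    Ok(42)
}

// clippy::needless_question_mark: `Ok(fetch()?)` unwraps the Result only
// to rewrap it immediately; plain `fetch()` would behave identically.
#[allow(clippy::needless_question_mark)]
fn relay() -> Result<i32, String> {
    Ok(fetch()?)
}

fn main() {
    assert_eq!(relay(), Ok(42));
}
```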
1 change: 1 addition & 0 deletions dfdx-nn/src/layers/conv1d.rs
@@ -97,6 +97,7 @@ where
 {
     #[param]
     #[serialize]
+    #[allow(clippy::type_complexity)]
     pub weight: Tensor<
         (
             OutChan,
1 change: 1 addition & 0 deletions dfdx-nn/src/layers/conv2d.rs
@@ -118,6 +118,7 @@ where
 {
     #[param]
     #[serialize]
+    #[allow(clippy::type_complexity)]
     pub weight: Tensor<
         (
             OutChan,
1 change: 1 addition & 0 deletions dfdx-nn/src/layers/conv_trans2d.rs
@@ -96,6 +96,7 @@ where
 {
     #[param]
     #[serialize]
+    #[allow(clippy::type_complexity)]
     pub weight: Tensor<
         (
             InChan,
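The last three files get the same treatment: the convolution layers' `weight` fields have deeply nested `Tensor<(..), ..>` types that trip clippy's `type_complexity` lint. The usual remedy is a type alias, but hiding the shape of a public, documented field arguably hurts readability, so a per-field `#[allow]` is the lighter fix here. A sketch of that trade-off on an unrelated, made-up type:

```rust
use std::collections::HashMap;

// Keeping the concrete type inline documents the field fully,
// at the cost of an #[allow] for clippy::type_complexity.
struct Registry {
    #[allow(clippy::type_complexity)]
    handlers: HashMap<String, Vec<Box<dyn Fn(&str) -> Result<String, String>>>>,
}

// The alias alternative satisfies clippy but pushes readers one
// definition away from the concrete type.
type Handler = Box<dyn Fn(&str) -> Result<String, String>>;

struct AliasedRegistry {
    handlers: HashMap<String, Vec<Handler>>,
}

fn main() {
    let registry = Registry { handlers: HashMap::new() };
    let aliased = AliasedRegistry { handlers: HashMap::new() };
    assert_eq!(registry.handlers.len(), aliased.handlers.len());
}
```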
