Skip to content

Commit

Permalink
Demonstration - BroadcastShape<Self> as "supertrait" for Dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
bluss committed Jan 23, 2021
1 parent 979d6df commit 54c3a85
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
5 changes: 3 additions & 2 deletions src/dimension/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ where
Ok(out)
}

pub trait BroadcastShape<Other: Dimension>: Dimension {
pub trait BroadcastShape<Other: Dimension> {
/// The resulting dimension type after broadcasting.
type BroadcastOutput: Dimension;

Expand All @@ -52,7 +52,8 @@ pub trait BroadcastShape<Other: Dimension>: Dimension {
/// Uses the [NumPy broadcasting rules]
/// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
fn broadcast_shape(&self, other: &Other) -> Result<Self::BroadcastOutput, ShapeError> {
broadcast_shape::<Self, Other, Self::BroadcastOutput>(self, other)
panic!()
//broadcast_shape::<Self, Other, Self::BroadcastOutput>(self, other)
}
}

Expand Down
3 changes: 3 additions & 0 deletions src/dimension/dimension_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::IntoDimension;
use crate::RemoveAxis;
use crate::{ArrayView1, ArrayViewMut1};
use crate::{Dim, Ix, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, IxDynImpl, Ixs, SliceOrIndex};
use crate::dimension::broadcast::BroadcastShape;

/// Array shape and index trait.
///
Expand All @@ -46,6 +47,8 @@ pub trait Dimension:
+ MulAssign
+ for<'x> MulAssign<&'x Self>
+ MulAssign<usize>
+ BroadcastShape<Self, BroadcastOutput=Self>
+ BroadcastShape<Ix0, BroadcastOutput=Self>
{
/// For fixed-size dimension representations (e.g. `Ix2`), this should be
/// `Some(ndim)`, and for variable-size dimension representations (e.g.
Expand Down
3 changes: 1 addition & 2 deletions src/numeric/impl_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,11 +283,10 @@ where
/// a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5)
/// );
/// ```
pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, <D::Smaller as BroadcastShape<Ix0>>::BroadcastOutput>>
pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, D::Smaller>>
where
A: Clone + Zero + FromPrimitive + Add<Output = A> + Div<Output = A>,
D: RemoveAxis,
D::Smaller: BroadcastShape<Ix0>,
{
let axis_length = self.len_of(axis);
if axis_length == 0 {
Expand Down

0 comments on commit 54c3a85

Please sign in to comment.