From f8f5c269671c9ef9671a45adb016f2b32634e105 Mon Sep 17 00:00:00 2001 From: Zhang Zhimin Date: Wed, 31 Jan 2018 21:46:25 +0800 Subject: [PATCH 1/2] fix slice --- matazure/common.hpp | 10 +++++----- matazure/point.hpp | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/matazure/common.hpp b/matazure/common.hpp index 5738bc1..eaa1b24 100644 --- a/matazure/common.hpp +++ b/matazure/common.hpp @@ -1,4 +1,4 @@ -#pragma once +#pragma once #include #include @@ -502,11 +502,11 @@ inline auto slice(_Tensor ts, int_t positon_index)->decltype(make_lambda(interna } /// special for slice(tensor<_T, rank>, position_index), it produces a tensor<_T, rank-1> -template -inline auto slice(tensor<_T, _Rank, _Layout> ts, int_t positon_index, enable_if_t<_DimIdx == _Rank-1>* = nullptr)->tensor<_T, _Rank-1, _Layout>{ - const auto slice_ext = internal::slice_point<_DimIdx>(ts.shape()); +template +inline auto dense_slice(tensor<_T, _Rank, first_major_layout<_Rank>> ts, int_t positon_index)->tensor<_T, _Rank-1, first_major_layout<_Rank-1>>{ + const auto slice_ext = internal::slice_point<_Rank-1>(ts.shape()); auto slice_size = cumulative_prod(slice_ext)[_Rank-1]; - tensor<_T, _Rank-1, _Layout> ts_re(slice_ext, shared_ptr<_T>(ts.shared_data().get() + positon_index * slice_size, [ts](_T *){ })); + tensor<_T, _Rank-1, first_major_layout<_Rank-1>> ts_re(slice_ext, shared_ptr<_T>(ts.shared_data().get() + positon_index * slice_size, [ts](_T *){ })); return ts_re; } diff --git a/matazure/point.hpp b/matazure/point.hpp index 9919dc0..7fbfe95 100644 --- a/matazure/point.hpp +++ b/matazure/point.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include namespace matazure { From dddcf587bbb42742ac1268b4c266f9087b7b7b6a Mon Sep 17 00:00:00 2001 From: Zhang Zhimin Date: Thu, 1 Feb 2018 15:57:07 +0800 Subject: [PATCH 2/2] disable tensor only support pod value_type --- matazure/point.hpp | 2 +- matazure/tensor.hpp | 37 ++++++++++++++++++++++++++----------- test/main.cpp | 1 - 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/matazure/point.hpp b/matazure/point.hpp index 7fbfe95..4ec9c2b 100644 --- a/matazure/point.hpp +++ b/matazure/point.hpp @@ -1,4 +1,4 @@ -#pragma once +#pragma once #include diff --git a/matazure/tensor.hpp b/matazure/tensor.hpp index cd3e97b..8d165f0 100644 --- a/matazure/tensor.hpp +++ b/matazure/tensor.hpp @@ -1,4 +1,4 @@ -/** +/** * Defines tensor classes of host end */ @@ -38,6 +38,21 @@ class first_major_layout{ }); } + first_major_layout(const first_major_layout &rhs) : + first_major_layout(rhs.shape()) + { } + + first_major_layout & operator=(const first_major_layout &rhs) { + shape_ = rhs.shape(); + stride_ = get_stride(shape_); + + matazure::for_each(shape_, [](int_t b) { + if (b < 0) throw invalid_shape{}; + }); + + return *this; + } + MATAZURE_GENERAL int_t index2offset(const pointi &id) const { int_t offset = id[0]; for (int_t i = 1; i < rank; ++i) { @@ -78,8 +93,8 @@ class first_major_layout{ } private: - const pointi shape_; - const pointi stride_; + pointi shape_; + pointi stride_; }; template @@ -453,7 +468,7 @@ struct is_tensor> : bool_constant {}; template > class tensor : public tensor_expression> { public: - static_assert(std::is_pod<_ValueType>::value, "only supports pod type now"); + //static_assert(std::is_pod<_ValueType>::value, "only supports pod type now"); /// the rank of tensor static const int_t rank = _Rank; /** @@ -662,10 +677,10 @@ class tensor : public tensor_expression> { #endif public: - const pointi extent_; - const layout_type layout_; - const shared_ptr sp_data_; - value_type * const data_; + pointi extent_; + layout_type layout_; + shared_ptr sp_data_; + value_type * data_; }; using column_major_layout = first_major_layout<2>; @@ -675,9 +690,9 @@ using row_major_layout = last_major_layout<2>; template using matrix = tensor<_ValueType, 2, _Layout>; -/// alias of tensor <_ValueType, 1> -template > -using vector = tensor<_ValueType, 1, _Layout>; +///// alias of tensor <_ValueType, 1> +//template > +//using vector = tensor<_ValueType, 1, _Layout>; /// alias of tensor, _BlockDim::size(), _Layout> template > diff --git a/test/main.cpp b/test/main.cpp index 32d8e83..3550b92 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -3,5 +3,4 @@ int main(int argc, char * argv[]){ testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); - }