Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO NOT MERGE] Tpetra: Initial support for execution_space instances in MultiVector #13160

Open
wants to merge 14 commits into
base: develop
Choose a base branch
from
187 changes: 186 additions & 1 deletion packages/tpetra/core/src/Tpetra_Details_WrappedDualView.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,28 +98,55 @@ using enableIfConstData = std::enable_if_t<hasConstData<DualViewType>::value>;
template <typename DualViewType>
using enableIfNonConstData = std::enable_if_t<!hasConstData<DualViewType>::value>;

/* sync_host functions */

template <typename DualViewType>
enableIfNonConstData<DualViewType>
sync_host(DualViewType dualView) {
// This will sync, but only if needed
dualView.sync_host();
}

template <typename DualViewType>
enableIfNonConstData<DualViewType>
sync_host(const typename DualViewType::t_host::execution_space& exec, DualViewType dualView) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
sync_host(const typename DualViewType::t_host::execution_space& exec, DualViewType dualView) {
sync_host(const typename DualViewType::t_host::execution_space& exec, DualViewType& dualView) {

(Not sure why you want to take by value in this case...)

// This will sync, but only if needed
dualView.sync_host();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dualView.sync_host();
dualView.sync_host(exec);

}

template <typename DualViewType>
enableIfConstData<DualViewType>
sync_host(DualViewType dualView) { }

template <typename DualViewType>
enableIfConstData<DualViewType>
sync_host(const typename DualViewType::t_host::execution_space& exec, DualViewType dualView) { }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to https://github.com/kokkos/kokkos/blob/2d7715239700f50169bc50a96a234b05c28c9a2e/containers/src/Kokkos_DualView.hpp#L721-L736, it seems that the Kokkos::DualView throws when calling sync_<host/device>. Is it intended that you change this behavior by making it a no-op ?


/* sync_device functions */

template <typename DualViewType>
enableIfNonConstData<DualViewType>
sync_device(DualViewType dualView) {
// This will sync, but only if needed
dualView.sync_device();
dualView.sync_device();
}

template <typename DualViewType>
enableIfNonConstData<DualViewType>
sync_device(const typename DualViewType::t_dev::execution_space& exec, DualViewType dualView) {
// This will sync, but only if needed
dualView.sync_device(exec);
}

template <typename DualViewType>
enableIfConstData<DualViewType>
sync_device(DualViewType dualView) { }

template <typename DualViewType>
enableIfConstData<DualViewType>
sync_device(const typename DualViewType::t_dev::execution_space& exec, DualViewType dualView) { }


}// end namespace Impl

/// \brief Whether WrappedDualView reference count checking is enabled. Initially true.
Expand Down Expand Up @@ -320,6 +347,19 @@ class WrappedDualView {
}
return dualView.view_device();
}

typename t_dev::const_type
getDeviceView(const typename DualViewType::t_dev::execution_space& exec, Access::ReadOnlyStruct
DEBUG_UVM_REMOVAL_ARGUMENT
) const
{
DEBUG_UVM_REMOVAL_PRINT_CALLER("getDeviceViewReadOnly");
if(needsSyncPath()) {
throwIfHostViewAlive();
impl::sync_device(exec, originalDualView);
}
return dualView.view_device();
}

t_dev
getDeviceView(Access::ReadWriteStruct
Expand All @@ -337,6 +377,23 @@ class WrappedDualView {
return dualView.view_device();
}

t_dev
getDeviceView(const typename DualViewType::t_dev::execution_space& exec, Access::ReadWriteStruct
DEBUG_UVM_REMOVAL_ARGUMENT
)
{
DEBUG_UVM_REMOVAL_PRINT_CALLER("getDeviceViewReadWrite");
static_assert(dualViewHasNonConstData,
"ReadWrite views are not available for DualView with const data");
if(needsSyncPath()) {
throwIfHostViewAlive();
impl::sync_device(exec,originalDualView);
originalDualView.modify_device();
}
return dualView.view_device();
}


t_dev
getDeviceView(Access::OverwriteAllStruct
DEBUG_UVM_REMOVAL_ARGUMENT
Expand All @@ -357,6 +414,21 @@ class WrappedDualView {
return dualView.view_device();
}


t_dev
getDeviceView(const typename DualViewType::t_dev::execution_space& exec, Access::OverwriteAllStruct s
DEBUG_UVM_REMOVAL_ARGUMENT
)
{
// Since we're never syncing in this case, the execution_space is meaningless here
#ifdef DEBUG_UVM_REMOVAL
return getDeviceView(s,callerstr,filestr,linnum);
#else
return getDeviceView(s);
#endif
}


template<class TargetDeviceType>
typename std::remove_reference<decltype(std::declval<DualViewType>().template view<TargetDeviceType>())>::type::const_type
getView (Access::ReadOnlyStruct s DEBUG_UVM_REMOVAL_ARGUMENT) const {
Expand All @@ -381,7 +453,31 @@ class WrappedDualView {
return dualView.template view<TargetDeviceType>();
}

template<class TargetDeviceType>
typename std::remove_reference<decltype(std::declval<DualViewType>().template view<TargetDeviceType>())>::type::const_type
getView (const typename TargetDeviceType::execution_space & exec, Access::ReadOnlyStruct s DEBUG_UVM_REMOVAL_ARGUMENT) const {
using ReturnViewType = typename std::remove_reference<decltype(std::declval<DualViewType>().template view<TargetDeviceType>())>::type::const_type;
using ReturnDeviceType = typename ReturnViewType::device_type;
constexpr bool returnDevice = std::is_same<ReturnDeviceType, DeviceType>::value;
if(returnDevice) {
DEBUG_UVM_REMOVAL_PRINT_CALLER("getView<Device>ReadOnly");
if(needsSyncPath()) {
throwIfHostViewAlive();
impl::sync_device(exec,originalDualView);
}
}
else {
DEBUG_UVM_REMOVAL_PRINT_CALLER("getView<Host>ReadOnly");
if(needsSyncPath()) {
throwIfDeviceViewAlive();
impl::sync_host(exec,originalDualView);
}
}

return dualView.template view<TargetDeviceType>();
}


template<class TargetDeviceType>
typename std::remove_reference<decltype(std::declval<DualViewType>().template view<TargetDeviceType>())>::type
getView (Access::ReadWriteStruct s DEBUG_UVM_REMOVAL_ARGUMENT) const {
Expand Down Expand Up @@ -414,6 +510,39 @@ class WrappedDualView {
}


template<class TargetDeviceType>
typename std::remove_reference<decltype(std::declval<DualViewType>().template view<TargetDeviceType>())>::type
getView (const typename TargetDeviceType::execution_space & exec,Access::ReadWriteStruct s DEBUG_UVM_REMOVAL_ARGUMENT) const {
using ReturnViewType = typename std::remove_reference<decltype(std::declval<DualViewType>().template view<TargetDeviceType>())>::type;
using ReturnDeviceType = typename ReturnViewType::device_type;
constexpr bool returnDevice = std::is_same<ReturnDeviceType, DeviceType>::value;

if(returnDevice) {
DEBUG_UVM_REMOVAL_PRINT_CALLER("getView<Device>ReadWrite");
static_assert(dualViewHasNonConstData,
"ReadWrite views are not available for DualView with const data");
if(needsSyncPath()) {
throwIfHostViewAlive();
impl::sync_device(exec,originalDualView);
originalDualView.modify_device();
}
}
else {
DEBUG_UVM_REMOVAL_PRINT_CALLER("getView<Host>ReadWrite");
static_assert(dualViewHasNonConstData,
"ReadWrite views are not available for DualView with const data");
if(needsSyncPath()) {
throwIfDeviceViewAlive();
impl::sync_host(exec,originalDualView);
originalDualView.modify_host();
}
}

return dualView.template view<TargetDeviceType>();
}



template<class TargetDeviceType>
typename std::remove_reference<decltype(std::declval<DualViewType>().template view<TargetDeviceType>())>::type
getView (Access::OverwriteAllStruct s DEBUG_UVM_REMOVAL_ARGUMENT) const {
Expand Down Expand Up @@ -450,6 +579,21 @@ class WrappedDualView {
}


template<class TargetDeviceType>
typename std::remove_reference<decltype(std::declval<DualViewType>().template view<TargetDeviceType>())>::type
getView (const typename TargetDeviceType::execution_space & exec, Access::OverwriteAllStruct s DEBUG_UVM_REMOVAL_ARGUMENT) const {
using ReturnViewType = typename std::remove_reference<decltype(std::declval<DualViewType>().template view<TargetDeviceType>())>::type;
using ReturnDeviceType = typename ReturnViewType::device_type;
// Since nothing syncs here, the ExecSpace is meaningless
#ifdef DEBUG_UVM_REMOVAL
return getView<TargetDeviceType>(s,callerstr,filestr,linnum);
#else
return getView<TargetDeviceType>(s);
#endif

}


typename t_host::const_type
getHostSubview(int offset, int numEntries, Access::ReadOnlyStruct
DEBUG_UVM_REMOVAL_ARGUMENT
Expand Down Expand Up @@ -503,6 +647,20 @@ class WrappedDualView {
return getSubview(dualView.view_device(), offset, numEntries);
}

typename t_dev::const_type
getDeviceSubview(const typename DualViewType::t_dev::execution_space& exec, int offset, int numEntries, Access::ReadOnlyStruct
DEBUG_UVM_REMOVAL_ARGUMENT
) const
{
DEBUG_UVM_REMOVAL_PRINT_CALLER("getDeviceSubviewReadOnly");
if(needsSyncPath()) {
throwIfHostViewAlive();
impl::sync_device(exec,originalDualView);
}
return getSubview(dualView.view_device(), offset, numEntries);
}


t_dev
getDeviceSubview(int offset, int numEntries, Access::ReadWriteStruct
DEBUG_UVM_REMOVAL_ARGUMENT
Expand All @@ -519,6 +677,22 @@ class WrappedDualView {
return getSubview(dualView.view_device(), offset, numEntries);
}

t_dev
getDeviceSubview(const typename DualViewType::t_dev::execution_space& exec, int offset, int numEntries, Access::ReadWriteStruct
DEBUG_UVM_REMOVAL_ARGUMENT
)
{
DEBUG_UVM_REMOVAL_PRINT_CALLER("getDeviceSubviewReadWrite");
static_assert(dualViewHasNonConstData,
"ReadWrite views are not available for DualView with const data");
if(needsSyncPath()) {
throwIfHostViewAlive();
impl::sync_device(exec, originalDualView);
originalDualView.modify_device();
}
return getSubview(dualView.view_device(), offset, numEntries);
}

t_dev
getDeviceSubview(int offset, int numEntries, Access::OverwriteAllStruct
DEBUG_UVM_REMOVAL_ARGUMENT
Expand All @@ -530,6 +704,17 @@ class WrappedDualView {
return getDeviceSubview(offset, numEntries, Access::ReadWrite);
}

t_dev
getDeviceSubview(const typename DualViewType::t_dev::execution_space& exec, int offset, int numEntries, Access::OverwriteAllStruct
DEBUG_UVM_REMOVAL_ARGUMENT
)
{
DEBUG_UVM_REMOVAL_PRINT_CALLER("getDeviceSubviewOverwriteAll");
static_assert(dualViewHasNonConstData,
"OverwriteAll views are not available for DualView with const data");
return getDeviceSubview(exec,offset, numEntries, Access::ReadWrite);
}


// Debugging functions to get copies of the view state
typename t_host::HostMirror getHostCopy() const {
Expand Down
Loading
Loading