Skip to content

Commit

Permalink
Reduce code duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
garth-wells committed Sep 24, 2024
1 parent bfed4bd commit 5d5cf01
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 72 deletions.
97 changes: 42 additions & 55 deletions cpp/dolfinx/fem/DirichletBC.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <dolfinx/common/types.h>
#include <functional>
#include <memory>
#include <optional>
#include <span>
#include <type_traits>
#include <utility>
Expand Down Expand Up @@ -469,18 +470,13 @@ class DirichletBC
return {_dofs0, _owned_indices0};
}

/// Set bc entries in `x` to `scale * x_bc`
/// @brief Set bc entries in `x` to `scale * (x0 - x_bc)`.
///
/// @param[in] x The array in which to set `scale * x_bc[i]`, where
/// x_bc[i] is the boundary value of x[i]. Entries in x that do not
/// have a Dirichlet condition applied to them are unchanged. The
/// length of x must be less than or equal to the index of the
/// greatest boundary dof index. To set values only for
/// degrees-of-freedom that are owned by the calling rank, the length
/// of the array @p x should be equal to the number of dofs owned by
/// this rank.
/// @param[in] x The array in which to set `scale * (x0 - x_bc)`
/// @param[in] x0 The array used in compute the value to set
/// @param[in] scale The scaling value to apply
void set(std::span<T> x, T scale = 1) const
void set(std::span<T> x, std::optional<std::span<const T>> x0,
T scale = 1) const
{
if (std::holds_alternative<std::shared_ptr<const Function<T, U>>>(_g))
{
Expand All @@ -489,51 +485,28 @@ class DirichletBC
std::span<const T> values = g->x()->array();
auto dofs1_g = _dofs1_g.empty() ? std::span(_dofs0) : std::span(_dofs1_g);
std::int32_t x_size = x.size();
for (std::size_t i = 0; i < _dofs0.size(); ++i)
if (x0.has_value())
{
if (_dofs0[i] < x_size)
std::span<const T> _x0 = x0.value();
assert(x.size() <= _x0.size());
for (std::size_t i = 0; i < _dofs0.size(); ++i)
{
assert(dofs1_g[i] < (std::int32_t)values.size());
x[_dofs0[i]] = scale * values[dofs1_g[i]];
if (_dofs0[i] < x_size)
{
assert(dofs1_g[i] < (std::int32_t)values.size());
x[_dofs0[i]] = scale * (values[dofs1_g[i]] - _x0[_dofs0[i]]);
}
}
}
}
else if (std::holds_alternative<std::shared_ptr<const Constant<T>>>(_g))
{
auto g = std::get<std::shared_ptr<const Constant<T>>>(_g);
std::vector<T> value = g->value;
int bs = _function_space->dofmap()->bs();
std::int32_t x_size = x.size();
std::ranges::for_each(_dofs0,
[x_size, bs, scale, &value, &x](auto dof)
{
if (dof < x_size)
x[dof] = scale * value[dof % bs];
});
}
}

/// @brief Set bc entries in `x` to `scale * (x0 - x_bc)`.
///
/// @param[in] x The array in which to set `scale * (x0 - x_bc)`
/// @param[in] x0 The array used in compute the value to set
/// @param[in] scale The scaling value to apply
void set(std::span<T> x, std::span<const T> x0, T scale = 1) const
{
if (std::holds_alternative<std::shared_ptr<const Function<T, U>>>(_g))
{
auto g = std::get<std::shared_ptr<const Function<T, U>>>(_g);
assert(g);
std::span<const T> values = g->x()->array();
assert(x.size() <= x0.size());
auto dofs1_g = _dofs1_g.empty() ? std::span(_dofs0) : std::span(_dofs1_g);
std::int32_t x_size = x.size();
for (std::size_t i = 0; i < _dofs0.size(); ++i)
else
{
if (_dofs0[i] < x_size)
for (std::size_t i = 0; i < _dofs0.size(); ++i)
{
assert(dofs1_g[i] < (std::int32_t)values.size());
x[_dofs0[i]] = scale * (values[dofs1_g[i]] - x0[_dofs0[i]]);
if (_dofs0[i] < x_size)
{
assert(dofs1_g[i] < (std::int32_t)values.size());
x[_dofs0[i]] = scale * values[dofs1_g[i]];
}
}
}
}
Expand All @@ -542,12 +515,26 @@ class DirichletBC
auto g = std::get<std::shared_ptr<const Constant<T>>>(_g);
const std::vector<T>& value = g->value;
std::int32_t bs = _function_space->dofmap()->bs();
std::ranges::for_each(_dofs0,
[&x, &x0, &value, scale, bs](auto dof)
{
if (dof < (std::int32_t)x.size())
x[dof] = scale * (value[dof % bs] - x0[dof]);
});
if (x0.has_value())
{
assert(x.size() <= x0.value().size());
std::ranges::for_each(_dofs0,
[&x, x0 = x0.value(), &value, scale, bs](auto dof)
{
if (dof < (std::int32_t)x.size())
x[dof] = scale * (value[dof % bs] - x0[dof]);
});
}
else
{
std::ranges::for_each(
_dofs0,
[x_size = x.size(), bs, scale, &value, &x](auto dof)
{
if (dof < x_size)
x[dof] = scale * value[dof % bs];
});
}
}
}

Expand Down
20 changes: 3 additions & 17 deletions cpp/dolfinx/fem/assembler.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,24 +421,10 @@ void set_bc(std::span<T> b,
const std::vector<std::shared_ptr<const DirichletBC<T, U>>>& bcs,
std::optional<std::span<const T>> x0, T scale = 1)
{
if (x0.has_value())
{

if (b.size() > x0.value().size())
throw std::runtime_error("Size mismatch between b and x0 vectors.");
for (auto& bc : bcs)
{
assert(bc);
bc->set(b, x0.value(), scale);
}
}
else
for (auto& bc : bcs)
{
for (auto& bc : bcs)
{
assert(bc);
bc->set(b, scale);
}
assert(bc);
bc->set(b, x0.value(), scale);
}
}
} // namespace dolfinx::fem

0 comments on commit 5d5cf01

Please sign in to comment.