Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added interface for saving gauge back to host #582

Merged
merged 1 commit into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions quda_interface.c
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,86 @@ void reorder_gauge_toQuda( const su3 ** const gaugefield, const CompressionType
tm_stopwatch_pop(&g_timers, 0, 0, "TM_QUDA");
}

/**
 * Copy the contents of the host-side QUDA buffer gauge_quda[0..3] into the
 * site-indexed gauge field `gaugefield`, reordering sites and directions.
 *
 * For every local lattice site the function
 *   1. computes the lexicographic index `j` used on the QUDA side and the
 *      index `tm_idx` used for `gaugefield` (the two orderings are swapped
 *      depending on USE_LZ_LY_LX_T),
 *   2. derives the even/odd offset `quda_idx` into gauge_quda (18 reals per
 *      link, odd sites stored after the first VOLUME/2 sites),
 *   3. rescales the links in gauge_quda in place:
 *        - NO_COMPRESSION with TM_QUDA_THETABC: each complex entry of
 *          direction mu is multiplied by -g_kappa/phase_mu,
 *        - otherwise, with TM_QUDA_APBC, the links in gauge_quda[3] on the
 *          last global time slice are negated,
 *   4. copies 18*gSize bytes per direction into gaugefield[tm_idx][mu].
 *
 * NOTE(review): this looks like the inverse of reorder_gauge_toQuda, whose
 * body is not visible here -- confirm the factors against that function.
 *
 * Side effect: gauge_quda itself is modified (step 3), so after this call
 * the buffer no longer holds what QUDA wrote into it.
 *
 * @param gaugefield   destination field, indexed as gaugefield[site][mu].
 *                     Declared pointer-to-const, but written through; the
 *                     caller must pass storage that is actually mutable.
 * @param compression  compression type; only NO_COMPRESSION enables the
 *                     theta boundary-condition rescaling.
 */
void reorder_gauge_fromQuda( const su3 ** const gaugefield, const CompressionType compression ) {
  tm_stopwatch_push(&g_timers, __func__, "");

#ifdef TM_USE_OMP
  #pragma omp parallel
  {
#endif
    // declared inside the (optional) parallel region, hence one copy per thread
    _Complex double tmpcplx;

    // NOTE(review): only the number of bytes copied depends on cpu_prec; the
    // element offsets into gauge_quda below do not -- confirm that a
    // non-double cpu_prec is really supported on this path.
    size_t gSize = (gauge_param.cpu_prec == QUDA_DOUBLE_PRECISION) ? sizeof(double) : sizeof(float);

    // now copy and reorder
#ifdef TM_USE_OMP
    #pragma omp for collapse(4)
#endif
    for( int x0=0; x0<T; x0++ )
      for( int x1=0; x1<LX; x1++ )
        for( int x2=0; x2<LY; x2++ )
          for( int x3=0; x3<LZ; x3++ ) {
#if USE_LZ_LY_LX_T
            int j = x3 + LZ*x2 + LY*LZ*x1 + LX*LY*LZ*x0;
            int tm_idx = x1 + LX*x2 + LY*LX*x3 + LZ*LY*LX*x0;
#else
            int j = x1 + LX*x2 + LY*LX*x3 + LZ*LY*LX*x0;
            int tm_idx = x3 + LZ*x2 + LY*LZ*x1 + LX*LY*LZ*x0;
#endif
            // site parity selects the even (0) or odd (1) half of the buffer
            int oddBit = (x0+x1+x2+x3) & 1;
            // 18 reals (9 complex numbers) per link
            int quda_idx = 18*(oddBit*VOLUME/2+j/2);

            if( compression == NO_COMPRESSION && quda_input.fermionbc == TM_QUDA_THETABC ) {
              // rescale by -g_kappa/phase_mu if compression is not used;
              // gauge_quda[0..2] pair with phase_1..phase_3, gauge_quda[3] with phase_0
              for( int i=0; i<9; i++ ) {
                tmpcplx = gauge_quda[0][quda_idx+2*i] + I*gauge_quda[0][quda_idx+2*i+1];
                tmpcplx *= -g_kappa/phase_1;
                gauge_quda[0][quda_idx+2*i]   = creal(tmpcplx);
                gauge_quda[0][quda_idx+2*i+1] = cimag(tmpcplx);

                tmpcplx = gauge_quda[1][quda_idx+2*i] + I*gauge_quda[1][quda_idx+2*i+1];
                tmpcplx *= -g_kappa/phase_2;
                gauge_quda[1][quda_idx+2*i]   = creal(tmpcplx);
                gauge_quda[1][quda_idx+2*i+1] = cimag(tmpcplx);

                tmpcplx = gauge_quda[2][quda_idx+2*i] + I*gauge_quda[2][quda_idx+2*i+1];
                tmpcplx *= -g_kappa/phase_3;
                gauge_quda[2][quda_idx+2*i]   = creal(tmpcplx);
                gauge_quda[2][quda_idx+2*i+1] = cimag(tmpcplx);

                tmpcplx = gauge_quda[3][quda_idx+2*i] + I*gauge_quda[3][quda_idx+2*i+1];
                tmpcplx *= -g_kappa/phase_0;
                gauge_quda[3][quda_idx+2*i]   = creal(tmpcplx);
                gauge_quda[3][quda_idx+2*i+1] = cimag(tmpcplx);
              }
            // when compression is not used, we can still force naive anti-periodic boundary conditions
            } else {
              // only sites on the last time slice of the global lattice are affected
              if ( quda_input.fermionbc == TM_QUDA_APBC && x0+g_proc_coords[0]*T == g_nproc_t*T-1 ) {
                for( int i=0; i<18; i++ ) {
                  gauge_quda[3][quda_idx+i] = -gauge_quda[3][quda_idx+i];
                }
              } // quda_input.fermionbc
            } // if(compression & boundary conditions)

            // The destination is declared const in the signature; the explicit
            // (void *) cast makes the intentional write-through visible instead
            // of silently discarding the qualifier in the call to memcpy.
#if USE_LZ_LY_LX_T
            memcpy( (void *)&(gaugefield[tm_idx][3]), &(gauge_quda[0][quda_idx]), 18*gSize);
            memcpy( (void *)&(gaugefield[tm_idx][2]), &(gauge_quda[1][quda_idx]), 18*gSize);
            memcpy( (void *)&(gaugefield[tm_idx][1]), &(gauge_quda[2][quda_idx]), 18*gSize);
            memcpy( (void *)&(gaugefield[tm_idx][0]), &(gauge_quda[3][quda_idx]), 18*gSize);
#else
            memcpy( (void *)&(gaugefield[tm_idx][1]), &(gauge_quda[0][quda_idx]), 18*gSize);
            memcpy( (void *)&(gaugefield[tm_idx][2]), &(gauge_quda[1][quda_idx]), 18*gSize);
            memcpy( (void *)&(gaugefield[tm_idx][3]), &(gauge_quda[2][quda_idx]), 18*gSize);
            memcpy( (void *)&(gaugefield[tm_idx][0]), &(gauge_quda[3][quda_idx]), 18*gSize);
#endif
          } // volume loop
#ifdef TM_USE_OMP
  } // OpenMP parallel closing brace
#endif

  tm_stopwatch_pop(&g_timers, 0, 0, "TM_QUDA");
} /* reorder_gauge_fromQuda */

void _loadGaugeQuda( const CompressionType compression ) {
static int first_call = 1;
// check if the currently loaded gauge field is also the current gauge field
Expand Down Expand Up @@ -624,6 +704,44 @@ void _loadGaugeQuda( const CompressionType compression ) {
set_quda_gauge_state(&quda_gauge_state, g_gauge_state.gauge_id, X1, X2, X3, X0, &gauge_param);
}

/**
 * Fetch a gauge field from QUDA into the host buffer gauge_quda and reorder
 * it into `gaugefield`.
 *
 * The QUDA gauge parameters are taken from the file-level gauge_param, with
 * the location set to QUDA_CPU_FIELD_LOCATION and the type chosen by
 * `savegaugetype`. After saveGaugeQuda() has filled gauge_quda,
 * reorder_gauge_fromQuda() copies the data into `gaugefield`.
 *
 * All ranks terminate with exit(2) if QUDA is not initialized, if no gauge
 * field is loaded in QUDA, or if `savegaugetype` is invalid. For the first
 * two conditions only rank 0 prints the message.
 *
 * @param gaugefield     destination field, passed on to reorder_gauge_fromQuda
 * @param savegaugetype  0 selects QUDA_WILSON_LINKS, 1 selects
 *                       QUDA_SMEARED_LINKS; anything else is an error
 * @param compression    passed on to reorder_gauge_fromQuda
 */
void _saveGaugeQuda( const su3 ** const gaugefield, const int savegaugetype, const CompressionType compression ) {

  if( !quda_initialized ) {
    if(g_proc_id == 0) {
      fprintf(stderr, "Error: QUDA must be initialized to call _saveGaugeQuda\n");
    }
    // every rank must stop here: continuing on the non-zero ranks would call
    // into QUDA without it being initialized
    exit(2);
  }

  if( !quda_gauge_state.loaded ) {
    if(g_proc_id == 0) {
      fprintf(stderr, "Error: gauge must be loaded in QUDA\n");
    }
    // same reasoning as above: no rank may proceed without a loaded gauge field
    exit(2);
  }

  // start from a fresh parameter struct, then take over the current gauge
  // parameters and redirect the output to host memory
  QudaGaugeParam savegauge_param = newQudaGaugeParam();
  savegauge_param = gauge_param;
  savegauge_param.location = QUDA_CPU_FIELD_LOCATION;

  if(savegaugetype == 0) {
    savegauge_param.type = QUDA_WILSON_LINKS;
    tm_debug_printf(0, 1, "# TM_QUDA: Called _saveGaugeQuda for gauge type: QUDA_WILSON_LINKS\n");
  } else if(savegaugetype == 1) {
    savegauge_param.type = QUDA_SMEARED_LINKS;
    tm_debug_printf(0, 1, "# TM_QUDA: Called _saveGaugeQuda for gauge type: QUDA_SMEARED_LINKS\n");
  } else {
    fprintf(stderr, "Error: Invalid gauge type\n");
    exit(2);
  }

  // time only the QUDA call; the reordering pushes its own timer
  tm_stopwatch_push(&g_timers, "saveGaugeQuda", "");
  saveGaugeQuda((void *)gauge_quda, &savegauge_param);
  tm_stopwatch_pop(&g_timers, 0, 0, "TM_QUDA");
  reorder_gauge_fromQuda(gaugefield, compression);

}

// reorder spinor to QUDA format
void reorder_spinor_toQuda( double* sp, QudaPrecision precision, int doublet ) {
tm_stopwatch_push(&g_timers, __func__, "");
Expand Down
1 change: 1 addition & 0 deletions quda_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ void _initQuda();
void _endQuda();
void _loadGaugeQuda(const CompressionType);
void _loadCloverQuda(QudaInvertParam * inv_param);
// Retrieve a gauge field from QUDA via saveGaugeQuda() and reorder it into
// `gaugefield`. `savegaugetype` selects the links to fetch: 0 for
// QUDA_WILSON_LINKS, 1 for QUDA_SMEARED_LINKS; any other value prints an
// error and calls exit(2). `compression` is forwarded to the reordering step.
// QUDA must be initialized and a gauge field must be loaded in QUDA.
void _saveGaugeQuda( const su3 ** const gaugefield, const int savegaugetype, const CompressionType compression );

// direct line to QUDA inverter, no messing about with even/odd reordering
// source and propagator Should be full VOLUME spinor fields
Expand Down
Loading