Skip to content

Commit

Permalink
added interface for saving gauge back to host
Browse files Browse the repository at this point in the history
  • Loading branch information
aniketsen committed Feb 6, 2024
1 parent 3d3c0b8 commit 37dc1ab
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 0 deletions.
118 changes: 118 additions & 0 deletions quda_interface.c
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,86 @@ void reorder_gauge_toQuda( const su3 ** const gaugefield, const CompressionType
tm_stopwatch_pop(&g_timers, 0, 0, "TM_QUDA");
}

void reorder_gauge_fromQuda( const su3 ** const gaugefield, const CompressionType compression ) {
tm_stopwatch_push(&g_timers, __func__, "");

#ifdef TM_USE_OMP
#pragma omp parallel
{
#endif
_Complex double tmpcplx;

size_t gSize = (gauge_param.cpu_prec == QUDA_DOUBLE_PRECISION) ? sizeof(double) : sizeof(float);

// now copy and reorder
#ifdef TM_USE_OMP
#pragma omp for collapse(4)
#endif
for( int x0=0; x0<T; x0++ )
for( int x1=0; x1<LX; x1++ )
for( int x2=0; x2<LY; x2++ )
for( int x3=0; x3<LZ; x3++ ) {
#if USE_LZ_LY_LX_T
int j = x3 + LZ*x2 + LY*LZ*x1 + LX*LY*LZ*x0;
int tm_idx = x1 + LX*x2 + LY*LX*x3 + LZ*LY*LX*x0;
#else
int j = x1 + LX*x2 + LY*LX*x3 + LZ*LY*LX*x0;
int tm_idx = x3 + LZ*x2 + LY*LZ*x1 + LX*LY*LZ*x0;
#endif
int oddBit = (x0+x1+x2+x3) & 1;
int quda_idx = 18*(oddBit*VOLUME/2+j/2);

if( compression == NO_COMPRESSION && quda_input.fermionbc == TM_QUDA_THETABC ) {
// apply theta boundary conditions if compression is not used
for( int i=0; i<9; i++ ) {
tmpcplx = gauge_quda[0][quda_idx+2*i] + I*gauge_quda[0][quda_idx+2*i+1];
tmpcplx *= -g_kappa/phase_1;
gauge_quda[0][quda_idx+2*i] = creal(tmpcplx);
gauge_quda[0][quda_idx+2*i+1] = cimag(tmpcplx);

tmpcplx = gauge_quda[1][quda_idx+2*i] + I*gauge_quda[1][quda_idx+2*i+1];
tmpcplx *= -g_kappa/phase_2;
gauge_quda[1][quda_idx+2*i] = creal(tmpcplx);
gauge_quda[1][quda_idx+2*i+1] = cimag(tmpcplx);

tmpcplx = gauge_quda[2][quda_idx+2*i] + I*gauge_quda[2][quda_idx+2*i+1];
tmpcplx *= -g_kappa/phase_3;
gauge_quda[2][quda_idx+2*i] = creal(tmpcplx);
gauge_quda[2][quda_idx+2*i+1] = cimag(tmpcplx);

tmpcplx = gauge_quda[3][quda_idx+2*i] + I*gauge_quda[3][quda_idx+2*i+1];
tmpcplx *= -g_kappa/phase_0;
gauge_quda[3][quda_idx+2*i] = creal(tmpcplx);
gauge_quda[3][quda_idx+2*i+1] = cimag(tmpcplx);
}
// when compression is not used, we can still force naive anti-periodic boundary conditions
} else {
if ( quda_input.fermionbc == TM_QUDA_APBC && x0+g_proc_coords[0]*T == g_nproc_t*T-1 ) {
for( int i=0; i<18; i++ ) {
gauge_quda[3][quda_idx+i] = -gauge_quda[3][quda_idx+i];
}
} // quda_input.fermionbc
} // if(compression & boundary conditions)

#if USE_LZ_LY_LX_T
memcpy( &(gaugefield[tm_idx][3]), &(gauge_quda[0][quda_idx]), 18*gSize);
memcpy( &(gaugefield[tm_idx][2]), &(gauge_quda[1][quda_idx]), 18*gSize);
memcpy( &(gaugefield[tm_idx][1]), &(gauge_quda[2][quda_idx]), 18*gSize);
memcpy( &(gaugefield[tm_idx][0]), &(gauge_quda[3][quda_idx]), 18*gSize);
#else
memcpy( &(gaugefield[tm_idx][1]), &(gauge_quda[0][quda_idx]), 18*gSize);
memcpy( &(gaugefield[tm_idx][2]), &(gauge_quda[1][quda_idx]), 18*gSize);
memcpy( &(gaugefield[tm_idx][3]), &(gauge_quda[2][quda_idx]), 18*gSize);
memcpy( &(gaugefield[tm_idx][0]), &(gauge_quda[3][quda_idx]), 18*gSize);
#endif
} // volume loop
#ifdef TM_USE_OMP
} // OpenMP parallel closing brace
#endif

tm_stopwatch_pop(&g_timers, 0, 0, "TM_QUDA");
} /* reorder_gauge_fromQuda */

void _loadGaugeQuda( const CompressionType compression ) {
static int first_call = 1;
// check if the currently loaded gauge field is also the current gauge field
Expand Down Expand Up @@ -624,6 +704,44 @@ void _loadGaugeQuda( const CompressionType compression ) {
set_quda_gauge_state(&quda_gauge_state, g_gauge_state.gauge_id, X1, X2, X3, X0, &gauge_param);
}

void _saveGaugeQuda( const su3 ** const gaugefield, const int savegaugetype, const CompressionType compression ) {

if( !quda_initialized ) {
if(g_proc_id == 0) {
fprintf(stderr, "Error: QUDA must be initialized to call _saveGaugeQuda\n");
exit(2);
}
}

if( !quda_gauge_state.loaded ) {
if(g_proc_id == 0) {
fprintf(stderr, "Error: gauge must be loaded in QUDA\n");
exit(2);
}
}

QudaGaugeParam savegauge_param = newQudaGaugeParam();
savegauge_param = gauge_param;
savegauge_param.location = QUDA_CPU_FIELD_LOCATION;

if(savegaugetype == 0) {
savegauge_param.type = QUDA_WILSON_LINKS;
tm_debug_printf(0, 1, "# TM_QUDA: Called _saveGaugeQuda for gauge type: QUDA_WILSON_LINKS\n");
} else if(savegaugetype == 1) {
savegauge_param.type = QUDA_SMEARED_LINKS;
tm_debug_printf(0, 1, "# TM_QUDA: Called _saveGaugeQuda for gauge type: QUDA_SMEARED_LINKS\n");
} else {
fprintf(stderr, "Error: Invalid gauge type\n");
exit(2);
}

tm_stopwatch_push(&g_timers, "saveGaugeQuda", "");
saveGaugeQuda((void *)gauge_quda, &savegauge_param);
tm_stopwatch_pop(&g_timers, 0, 0, "TM_QUDA");
reorder_gauge_fromQuda(gaugefield, compression);

}

// reorder spinor to QUDA format
void reorder_spinor_toQuda( double* sp, QudaPrecision precision, int doublet ) {
tm_stopwatch_push(&g_timers, __func__, "");
Expand Down
1 change: 1 addition & 0 deletions quda_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ void _initQuda();
void _endQuda();
void _loadGaugeQuda(const CompressionType);
void _loadCloverQuda(QudaInvertParam * inv_param);
void _saveGaugeQuda( const su3 ** const gaugefield, const int savegaugetype, const CompressionType compression );

// direct line to QUDA inverter, no messing about with even/odd reordering
// source and propagator Should be full VOLUME spinor fields
Expand Down

0 comments on commit 37dc1ab

Please sign in to comment.