Skip to content

Commit

Permalink
Added hash sum as dummy parameter for SpatialTemplate_2D (#70)
Browse files Browse the repository at this point in the history
* Added hash sum as dummy parameter for SpatialTemplate_2D

* Added test for SpatialTemplate_2D
  • Loading branch information
henrikef authored and giacomov committed Jun 19, 2018
1 parent 87b847f commit 1742826
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 5 deletions.
23 changes: 19 additions & 4 deletions astromodels/functions/functions_2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from astromodels.utils.angular_distance import angular_distance
from astromodels.utils.vincenty import vincenty

import hashlib


class Latitude_galactic_diffuse(Function2D):
r"""
Expand Down Expand Up @@ -534,6 +536,12 @@ class SpatialTemplate_2D(Function2D):
desc : normalization
initial value : 1
fix : yes
hash :
desc: hash of model map [needed for memoization]
initial value: 1
fix: yes
"""

Expand All @@ -556,14 +564,21 @@ def load_file(self,fitsfile,ihdu=0):

self._wcs = wcs.WCS( header = f[ihdu].header )
self._map = f[ihdu].data

self._nX = f[ihdu].header['NAXIS1']
self._nY = f[ihdu].header['NAXIS2']

#note: map coordinates are switched compared to header. NAXIS1 is coordinate 1, not 0.
#see http://docs.astropy.org/en/stable/io/fits/#working-with-image-data
assert self._map.shape[1] == self._nX, "NAXIS1 = %d in fits header, but %d in map" % (self._nX, self._map.shape[1])
assert self._map.shape[0] == self._nY, "NAXIS2 = %d in fits header, but %d in map" % (self._nY, self._map.shape[0])

#note: map coordinates are switched compared to header. NAXIS1 is coordinate 1, not 0.
#see http://docs.astropy.org/en/stable/io/fits/#working-with-image-data
#hash sum uniquely identifying the template function (defined by its 2D map array and coordinate system)
#this is needed so that the memoization won't confuse different SpatialTemplate_2D objects.
h = hashlib.sha224()
h.update( self._map)
h.update( repr(self._wcs) )
self.hash = int(h.hexdigest(), 16)


def set_frame(self, new_frame):
Expand All @@ -577,7 +592,7 @@ def set_frame(self, new_frame):

self._frame = new_frame

def evaluate(self, x, y, K):
def evaluate(self, x, y, K, hash):

# We assume x and y are R.A. and Dec
coord = SkyCoord(ra=x, dec=y, frame=self._frame, unit="deg")
Expand Down
69 changes: 68 additions & 1 deletion astromodels/tests/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import os

import astropy.units as u
import numpy as np
Expand All @@ -7,10 +8,12 @@
from astromodels.functions.function import FunctionMeta, Function1D, Function2D, FunctionDefinitionError, \
UnknownParameter, DesignViolation, get_function, get_function_class, UnknownFunction, list_functions
from astromodels.functions.functions import Powerlaw, Line
from astromodels.functions.functions_2D import Gaussian_on_sphere
from astromodels.functions.functions_2D import Gaussian_on_sphere, SpatialTemplate_2D
from astromodels.functions.functions_3D import Continuous_injection_diffusion
from astromodels.functions import function as function_module

from astropy.io import fits

__author__ = 'giacomov'


Expand Down Expand Up @@ -849,3 +852,67 @@ def test_function3D():
with pytest.raises(TypeError):

c.set_units("not existent", u.deg, u.keV, 1.0 / (u.keV * u.s * u.deg**2 * u.cm**2))

def test_spatial_template_2D():

#make the fits files with templates to test.
cards = {
"SIMPLE": "T",
"BITPIX": -32,
"NAXIS" : 2,
"NAXIS1": 360,
"NAXIS2": 360,
"DATE": '2018-06-15',
"CUNIT1": 'deg',
"CRVAL1": 83,
"CRPIX1": 0,
"CDELT1": -0.0166667,
"CUNIT2": 'deg',
"CRVAL2": -2.0,
"CRPIX2": 0,
"CDELT2": 0.0166667,
"CTYPE1": 'GLON-CAR',
"CTYPE2": 'GLAT-CAR' }

data = np.zeros([400,400])
data[0:100,0:100] = 1
hdu = fits.PrimaryHDU(data=data, header=fits.Header(cards))
hdu.writeto("test1.fits", overwrite=True)

data[:,:]=0
data[200:300,200:300] = 1
hdu = fits.PrimaryHDU(data=data, header=fits.Header(cards))
hdu.writeto("test2.fits", overwrite=True)


#Now load template files and test their evaluation
shape1=SpatialTemplate_2D()
shape1.load_file("test1.fits")
shape1.K = 1

shape2=SpatialTemplate_2D()
shape2.load_file("test2.fits")
shape2.K = 1

assert shape1.hash != shape2.hash

assert np.all ( shape1.evaluate( [312, 306], [41, 41], [1,1], [40, 2]) == [1., 0.] )
assert np.all ( shape2.evaluate( [312, 306], [41, 41], [1,1], [40, 2]) == [0., 1.] )
assert np.all ( shape1.evaluate( [312, 306], [41, 41], [1,10], [40, 2]) == [1., 0.] )
assert np.all ( shape2.evaluate( [312, 306], [41, 41], [1,10], [40, 2]) == [0., 10.] )


shape1.K = 1
shape2.K = 1
assert np.all ( shape1( [312, 306], [41, 41]) == [1., 0.] )
assert np.all ( shape2( [312, 306], [41, 41]) == [0., 1.] )

shape1.K = 1
shape2.K = 10
assert np.all ( shape1( [312, 306], [41, 41]) == [1., 0.] )
assert np.all ( shape2( [312, 306], [41, 41]) == [0., 10.] )

os.remove("test1.fits")
os.remove("test2.fits")


0 comments on commit 1742826

Please sign in to comment.