# Licensed under a 3-clause BSD style license - see LICENSE.rst
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import abc
import pickle
import hashlib
import os
import requests
from astropy.extern import six
from astropy.config import paths
from astropy import log
import astropy.units as u
from astropy.utils.console import ProgressBarOrSpinner
import astropy.utils.data
from . import version
__all__ = ['BaseQuery', 'QueryWithLogin']
def to_cache(response, cache_file):
log.debug("Caching data to {0}".format(cache_file))
with open(cache_file, "wb") as f:
pickle.dump(response, f)
def _replace_none_iterable(iterable):
return tuple('' if i is None else i for i in iterable)
class AstroQuery(object):
def __init__(self, method, url, params=None, data=None, headers=None,
files=None, timeout=None):
self.method = method
self.url = url
self.params = params
self.data = data
self.headers = headers
self.files = files
self._hash = None
self.timeout = timeout
@property
def timeout(self):
return self._timeout
@timeout.setter
def timeout(self, value):
if hasattr(value, 'to'):
self._timeout = value.to(u.s).value
else:
self._timeout = value
def request(self, session, cache_location=None, stream=False,
auth=None, verify=True):
return session.request(self.method, self.url, params=self.params,
data=self.data, headers=self.headers,
files=self.files, timeout=self.timeout,
stream=stream, auth=auth, verify=verify)
def hash(self):
if self._hash is None:
request_key = (self.method, self.url)
for k in (self.params, self.data, self.headers, self.files):
if isinstance(k, dict):
entry = (tuple(sorted(k.items(),
key=_replace_none_iterable)))
entry = tuple((k_, v_.read()) if hasattr(v_, 'read')
else (k_, v_) for k_, v_ in entry)
for k_, v_ in entry:
if hasattr(v_, 'read') and hasattr(v_, 'seek'):
v_.seek(0)
request_key += entry
elif isinstance(k, tuple) or isinstance(k, list):
request_key += (tuple(sorted(k,
key=_replace_none_iterable)),)
elif k is None:
request_key += (None,)
elif isinstance(k, six.string_types):
request_key += (k,)
else:
raise TypeError("{0} must be a dict, tuple, str, or "
"list".format(k))
self._hash = hashlib.sha224(pickle.dumps(request_key)).hexdigest()
return self._hash
def request_file(self, cache_location):
fn = os.path.join(cache_location, self.hash() + ".pickle")
return fn
def from_cache(self, cache_location):
request_file = self.request_file(cache_location)
try:
with open(request_file, "rb") as f:
response = pickle.load(f)
if not isinstance(response, requests.Response):
response = None
except IOError: # TODO: change to FileNotFoundError once drop py2 support
response = None
if response:
log.debug("Retrieving data from {0}".format(request_file))
return response
@six.add_metaclass(abc.ABCMeta)
[docs]class BaseQuery(object):
"""
This is the base class for all the query classes in astroquery. It
is implemented as an abstract class and must not be directly instantiated.
"""
def __init__(self):
S = self._session = requests.session()
S.headers['User-Agent'] = (
'astroquery/{vers} {olduseragent}'
.format(vers=version.version,
olduseragent=S.headers['User-Agent']))
self.cache_location = os.path.join(
paths.get_cache_dir(), 'astroquery',
self.__class__.__name__.split("Class")[0])
if not os.path.exists(self.cache_location):
os.makedirs(self.cache_location)
self._cache_active = True
[docs] def __call__(self, *args, **kwargs):
""" init a fresh copy of self """
return self.__class__(*args, **kwargs)
def _request(self, method, url, params=None, data=None, headers=None,
files=None, save=False, savedir='', timeout=None, cache=True,
stream=False, auth=None, continuation=True, verify=True):
"""
A generic HTTP request method, similar to `requests.Session.request`
but with added caching-related tools
This is a low-level method not generally intended for use by astroquery
end-users. As such, it is likely to be renamed to, e.g., `_request` in
the near future.
Parameters
----------
method : str
'GET' or 'POST'
url : str
params : None or dict
data : None or dict
headers : None or dict
auth : None or dict
files : None or dict
See `requests.request`
save : bool
Whether to save the file to a local directory. Caching will happen
independent of this parameter if `BaseQuery.cache_location` is set,
but the save location can be overridden if ``save==True``
savedir : str
The location to save the local file if you want to save it
somewhere other than `BaseQuery.cache_location`
timeout : int
cache : bool
verify : bool
continuation : bool
stream : bool
Returns
-------
response : `requests.Response`
The response from the server if ``save`` is False
local_filepath : list
a list of strings containing the downloaded local paths if ``save``
is True
"""
if save:
local_filename = url.split('/')[-1]
if os.name == 'nt':
# Windows doesn't allow special characters in filenames like
# ":" so replace them with an underscore
local_filename = local_filename.replace(':', '_')
local_filepath = os.path.join(self.cache_location or savedir or
'.', local_filename)
# REDUNDANT: spinner has this log.info("Downloading
# {0}...".format(local_filename))
self._download_file(url, local_filepath, timeout=timeout,
auth=auth, cache=cache,
continuation=continuation)
return local_filepath
else:
query = AstroQuery(method, url, params=params, data=data,
headers=headers, files=files, timeout=timeout)
if ((self.cache_location is None) or (not self._cache_active) or
(not cache)):
with suspend_cache(self):
response = query.request(self._session, stream=stream,
auth=auth, verify=verify)
else:
response = query.from_cache(self.cache_location)
if not response:
response = query.request(self._session,
self.cache_location,
stream=stream,
auth=auth,
verify=verify)
to_cache(response, query.request_file(self.cache_location))
self._last_query = query
return response
def _download_file(self, url, local_filepath, timeout=None, auth=None,
continuation=True, cache=False, **kwargs):
"""
Download a file. Resembles `astropy.utils.data.download_file` but uses
the local ``_session``
"""
response = self._session.get(url, timeout=timeout, stream=True,
auth=auth, **kwargs)
response.raise_for_status()
if 'content-length' in response.headers:
length = int(response.headers['content-length'])
else:
length = None
if ((os.path.exists(local_filepath) and ('Accept-Ranges' in
response.headers) and
continuation)):
open_mode = 'ab'
existing_file_length = os.stat(local_filepath).st_size
if length is not None and existing_file_length >= length:
# all done!
log.info("Found cached file {0} with expected size {1}."
.format(local_filepath, existing_file_length))
return
elif existing_file_length == 0:
open_mode = 'wb'
else:
log.info("Continuing download of file {0}, with {1} bytes to "
"go ({2}%)".format(local_filepath,
length - existing_file_length,
(length-existing_file_length)/length*100))
# bytes are indexed from 0:
# https://en.wikipedia.org/wiki/List_of_HTTP_header_fields#range-request-header
end = "{0}".format(length-1) if length is not None else ""
self._session.headers['Range'] = "bytes={0}-{1}".format(existing_file_length,
end)
response = self._session.get(url, timeout=timeout, stream=True,
auth=auth, **kwargs)
elif cache and os.path.exists(local_filepath):
if length is not None:
statinfo = os.stat(local_filepath)
if statinfo.st_size != length:
log.warning("Found cached file {0} with size {1} that is "
"different from expected size {2}"
.format(local_filepath,
statinfo.st_size,
length))
open_mode = 'wb'
else:
log.info("Found cached file {0} with expected size {1}."
.format(local_filepath, statinfo.st_size))
response.close()
return
else:
log.info("Found cached file {0}.".format(local_filepath))
response.close()
return
else:
open_mode = 'wb'
blocksize = astropy.utils.data.conf.download_block_size
bytes_read = 0
with ProgressBarOrSpinner(
length, ('Downloading URL {0} to {1} ...'
.format(url, local_filepath))) as pb:
with open(local_filepath, open_mode) as f:
for block in response.iter_content(blocksize):
f.write(block)
bytes_read += blocksize
if length is not None:
pb.update(bytes_read if bytes_read <= length else
length)
else:
pb.update(bytes_read)
response.close()
class suspend_cache:
"""
A context manager that suspends caching.
"""
def __init__(self, obj):
self.obj = obj
def __enter__(self):
self.obj._cache_active = False
def __exit__(self, exc_type, exc_value, traceback):
self.obj._cache_active = True
return False
[docs]class QueryWithLogin(BaseQuery):
"""
This is the base class for all the query classes which are required to
have a login to access the data.
The abstract method _login() must be implemented. It is wrapped by the
login() method, which turns off the cache. This way, login credentials
are not stored in the cache.
"""
def __init__(self):
super(QueryWithLogin, self).__init__()
self._authenticated = False
@abc.abstractmethod
def _login(self, *args, **kwargs):
"""
login to non-public data as a known user
Parameters
----------
Keyword arguments that can be used to create
the data payload(dict) sent via `requests.post`
"""
pass
[docs] def login(self, *args, **kwargs):
with suspend_cache(self):
self._authenticated = self._login(*args, **kwargs)
return self._authenticated
[docs] def authenticated(self):
return self._authenticated