Source code for ptp.datasets

"""Dataset manager
"""
import json
import logging
import os
import shutil
import subprocess

import requests

from ptp import util, docs

# Module-level logger, named after this module (``__name__``).
logger = logging.getLogger(__name__)


class Datasets():
    """Dataset manager

    Looks for datasets in a local repository (the ``data`` directory located
    one level above this package) and, when they are missing, downloads them
    from the dataset servers listed in the user configuration file
    ``~/.ptp/config.json``. Servers are reached either over SSH (via ``scp``)
    or through the RESTful API at ``self.api_url`` using a client SSL key and
    certificate.
    """

    def __init__(self):
        self._set_paths()
        self._check_cfg()
        self.cfg = self._load_cfg()
        if self.cfg is None:
            logger.error("Failed to load dataset server configurations")
        self.api_url = 'https://ptp.database.lasseufpa.org/api/'

    def _set_paths(self):
        """Define the local repository path and the configuration file path

        The local repository directory is created if it does not exist yet.
        """
        this_file = os.path.realpath(__file__)
        rootdir = os.path.dirname(os.path.dirname(this_file))
        self.local_repo = os.path.join(rootdir, "data")
        home = os.path.expanduser("~")
        self.cfg_path = os.path.join(home, ".ptp")
        self.cfg_file = os.path.join(self.cfg_path, "config.json")
        # Create local repo if it does not exist (exist_ok avoids a
        # check-then-create race)
        os.makedirs(self.local_repo, exist_ok=True)

    def _get_all_ds_variations(self, dataset):
        """Get all possible variations to the dataset name

        Args:
            dataset : dataset name or path (only its base name is used)

        Returns:
            Tuple ``(all_local_paths, all_ds_names)``: the candidate dataset
            file names, in the order in which they are tried, and the
            corresponding paths within the local repository.

        Note:
            Also stores the base name of ``dataset`` on ``self.ds_name``.
        """
        self.ds_name = os.path.basename(dataset)
        no_ext_ds_name = os.path.splitext(self.ds_name)[0]
        ds_prefix = no_ext_ds_name.replace("-comp", "")
        ds_suffixes = [
            "-comp.xz", "-comp.pbz2", "-comp.gz", "-comp.pickle", "-comp.json",
            ".json"
        ]
        all_ds_names = [ds_prefix + suffix for suffix in ds_suffixes]
        all_local_paths = [
            os.path.join(self.local_repo, d) for d in all_ds_names
        ]
        return all_local_paths, all_ds_names

    def _check_cfg(self):
        """Check if path to cfg folder exists or create it otherwise

        Raises:
            IsADirectoryError: if ``self.cfg_path`` exists but is not a
                directory.
        """
        if not os.path.exists(self.cfg_path):
            os.mkdir(self.cfg_path)
        elif not os.path.isdir(self.cfg_path):
            # NOTE(review): NotADirectoryError would describe this condition
            # more accurately; IsADirectoryError is kept so that existing
            # exception handlers keep working.
            raise IsADirectoryError(
                "{} already exists, but is not a directory".format(
                    self.cfg_path))

    def _load_cfg(self):
        """Load dataset server configurations from the configuration file

        If the configuration file does not exist, the user is prompted to
        create one (see ``_create_cfg``).

        Returns:
            The loaded configuration entries, or None if the user declined to
            provide them.
        """
        if os.path.exists(self.cfg_file):
            with open(self.cfg_file) as fd:
                cfg = json.load(fd)
            logger.info("Loaded dataset server configurations from {}".format(
                self.cfg_file))
        else:
            logger.info("Couldn't find access information for dataset server.")
            cfg = self._create_cfg()
        return cfg

    def _copy_key_cert(self, key, cert):
        """Copy key and digital certificate into the config directory

        Args:
            key : path to the SSL key (``~`` is expanded)
            cert : path to the SSL certificate (``~`` is expanded)

        Returns:
            Tuple with the paths of the copied key and certificate inside
            ``<cfg_path>/certs``.
        """
        certs_dir = os.path.join(self.cfg_path, "certs")
        # Create certs directory if it does not exist
        os.makedirs(certs_dir, exist_ok=True)
        # Copy key and cert to the certs directory
        new_key_path = shutil.copy(os.path.expanduser(key), certs_dir)
        new_crt_path = shutil.copy(os.path.expanduser(cert), certs_dir)
        return new_key_path, new_crt_path

    def _create_cfg(self):
        """Create configuration file with user-provided credentials

        Interactively asks for one or more server entries, each either 'SSH'
        (address, remote path, user) or 'API' (SSL key and certificate, which
        are copied into the config directory), and saves them as JSON in
        ``self.cfg_file``.

        Returns:
            List of configuration entries, or None if the user declined to
            provide them.

        Raises:
            ValueError: if the user enters an unknown download mode.
        """
        if not util.ask_yes_or_no("Provide information now?"):
            return None

        cfg = list()
        more = True
        while more:
            dl_mode = input("Download via API or SSH? (API) ") or "API"
            if dl_mode.upper() == 'SSH':
                server = input("IP address of the dataset server: ")
                path = input("Path to dataset repository on server: ")
                user = input("Username to access the server: ")
                cfg.append({
                    'dl_mode': 'SSH',
                    'addr': server,
                    'path': path,
                    'user': user
                })
            elif dl_mode.upper() == 'API':
                ssl_key_in = input("Path to your SSL key: ")
                ssl_crt_in = input("Path to your SSL certificate: ")
                ssl_key, ssl_crt = self._copy_key_cert(ssl_key_in, ssl_crt_in)
                cfg.append({
                    'dl_mode': 'API',
                    'ssl_key': ssl_key,
                    'ssl_crt': ssl_crt
                })
            else:
                raise ValueError(
                    "Download mode {} not defined".format(dl_mode))

            more = util.ask_yes_or_no("Add another address?")

        with open(self.cfg_file, 'w') as fd:
            json.dump(cfg, fd)

        logger.info(f"Saved dataset server configurations on {self.cfg_file}")
        return cfg

    def _download_ssh(self, cfg, ds_name):
        """Download dataset via SSH from dataset server

        The file is copied with ``scp`` into the local repository, i.e. the
        same directory that ``download`` checks and ``_download_api`` writes
        to (previously it went to a ``data/`` folder relative to the current
        working directory).

        Args:
            cfg : Configuration entry with user and server information
            ds_name : Dataset file name

        Return:
            Path to the file that was downloaded. None if not found, if scp
            could not be run, or if the transfer timed out.
        """
        ds_repo = cfg['user'] + "@" + cfg['addr'] + ":" + cfg['path']
        scp_src = os.path.join(ds_repo, ds_name)
        local_ds_path = os.path.join(self.local_repo, ds_name)
        cmd = ["scp", scp_src, local_ds_path]
        logger.info("Trying %s", scp_src)
        try:
            res = subprocess.run(cmd,
                                 timeout=60.0,
                                 stdout=subprocess.DEVNULL,
                                 stderr=subprocess.DEVNULL)
        except subprocess.TimeoutExpired:
            logger.warning("Timed out while downloading {} from {}".format(
                ds_name, ds_repo))
            return None
        except OSError as e:
            # e.g. the scp executable is not available on this system
            logger.warning("Couldn't run scp: {}".format(e))
            return None

        if res.returncode != 0:
            logger.debug("Couldn't find file {} in {}".format(
                ds_name, ds_repo))
            return None

        print("Downloaded {} from {}".format(ds_name, ds_repo))
        return local_ds_path

    def _download_api(self, cfg, ds_name):
        """Download dataset via RESTful API

        Args:
            cfg : Configuration entry with the SSL key/certificate paths
            ds_name : Dataset file name

        Return:
            Path to the file that was downloaded. None if not found.
        """
        addr = self.api_url + 'dataset'
        # Join URL components with "/" explicitly: os.path.join would use
        # "\\" on Windows.
        ds_req = addr + '/' + ds_name
        logger.info("Trying " + ds_req)
        cert = (cfg['ssl_crt'], cfg['ssl_key'])
        try:
            req = requests.get(ds_req, cert=cert, timeout=60.0)
            req.raise_for_status()
        except requests.exceptions.RequestException:
            logger.debug("Couldn't find file {} in {}".format(ds_name, addr))
            return None

        local_ds_path = os.path.join(self.local_repo, ds_name)
        with open(local_ds_path, 'wb') as fd:
            fd.write(req.content)
        print("Downloaded {} from {}".format(ds_name, addr))
        return local_ds_path

    def _download_from_server(self, entry, ds_names):
        """Try to download any of the candidate dataset names from one server

        Args:
            entry : server configuration entry; ``dl_mode`` 'SSH' uses scp,
                any other mode uses the RESTful API
            ds_names : candidate dataset file names, tried in order

        Returns:
            Path to the first file that was downloaded, or None.
        """
        if entry['dl_mode'] == 'SSH':
            download_fn = self._download_ssh
        else:
            download_fn = self._download_api
        for ds_name in ds_names:
            ds_path = download_fn(entry, ds_name)
            if ds_path is not None:
                return ds_path
        return None

    def download(self, dataset):
        """Download dataset from dataset server

        The local repository is checked first for any name variation of the
        dataset. Otherwise, each configured server is tried in turn and,
        for each server, the name variations are tried in order of compression
        (most compressed first). A downloaded dataset is added to the local
        catalog.

        Args:
            dataset : dataset name/path

        Returns:
            Path to dataset file

        Raises:
            RuntimeError: if the dataset is neither available locally nor on
                any configured server (including when no server is
                configured at all).
        """
        all_local_paths, all_ds_names = self._get_all_ds_variations(dataset)

        # Does the dataset exist on the local repository already?
        for path in all_local_paths:
            if os.path.exists(path):
                logger.info("Dataset already available locally")
                return path

        print("Dataset not available locally. Try to download from server.")

        # self.cfg is None when the user declined to configure a server; treat
        # that like an empty configuration instead of failing with TypeError.
        ds_path = None
        for entry in (self.cfg or []):
            ds_path = self._download_from_server(entry, all_ds_names)
            if ds_path is not None:
                break

        if ds_path is None:
            # The dataset was found on neither the local nor the remote
            # repositories
            raise RuntimeError("Couldn't find dataset")

        # Add to local catalog
        catalog = docs.Docs()
        catalog.add_dataset(ds_path)
        return ds_path

    def search(self, parameters):
        """Search datasets via RESTful API

        The API connections from the configuration are tried in order until
        one of them succeeds.

        Args:
            parameters : Dictionary with the query parameters

        Return:
            List with the datasets found, or None if no API connection is
            configured or the search failed on every connection.
        """
        addr = self.api_url + 'search'
        headers = {'content-type': 'application/json'}
        ds_found = None

        if self.cfg is None:
            return None

        api_connections = [e for e in self.cfg if e['dl_mode'] == 'API']
        if not api_connections:
            logger.error(
                "Couldn't find a dataset server in your configuration")
            return None

        for conn in api_connections:
            cert = (conn['ssl_crt'], conn['ssl_key'])
            try:
                req = requests.post(addr,
                                    data=json.dumps(parameters),
                                    headers=headers,
                                    cert=cert,
                                    timeout=60.0)
                req.raise_for_status()
                response = req.json()
                ds_found = response['found']
            except requests.exceptions.RequestException as e:
                # Read the status from the exception: ``req`` is unbound (or
                # stale) when the request itself failed, e.g. on connection
                # errors or timeouts.
                status_code = getattr(e.response, 'status_code', None)
                if status_code == 400:
                    logger.info("Bad request! Check your cfg file.")
                elif status_code == 404:
                    logger.info("No dataset found!")
                else:
                    logger.info(e)
                continue
            # All connections query the same server, so stop at the first
            # successful one.
            break

        return ds_found