diff --git a/gui/slick/views/config_postProcessing.mako b/gui/slick/views/config_postProcessing.mako
index 2435a9d2c8fe2c4299e4de025ee79f7ad4a52166..94a071e06f97608826b0a243f1751776b4afc779 100644
--- a/gui/slick/views/config_postProcessing.mako
+++ b/gui/slick/views/config_postProcessing.mako
@@ -158,6 +158,22 @@
                     <span class="component-desc"><b>NOTE:</b> Some systems may ignore this feature.</span>
                 </label>
             </div>
+            <div class="field-pair">
+                <label class="nocheck" for="file_timestamp_timezone">
+                    <span class="component-title">Timezone for File Date:</span>
+                    <span class="component-desc">
+                        <select name="file_timestamp_timezone" id="file_timestamp_timezone" class="form-control input-sm">
+                            % for curTimezone in ('local','network'):
+                            <option value="${curTimezone}" ${('', 'selected="selected"')[sickbeard.FILE_TIMESTAMP_TIMEZONE == curTimezone]}>${curTimezone}</option>
+                            % endfor
+                        </select>
+                    </span>
+                </label>
+                <label class="nocheck">
+                    <span class="component-title">&nbsp;</span>
+                    <span class="component-desc">Which timezone should be used to set the file date?</span>
+                </label>
+            </div>
             <div class="field-pair">
                 <input id="unpack" type="checkbox" name="unpack" ${('', 'checked="checked"')[bool(sickbeard.UNPACK)]} />
                 <label for="unpack">
diff --git a/gui/slick/views/config_search.mako b/gui/slick/views/config_search.mako
index 8e74093c3c99014ca985b861de20a13ebf9d9f1d..44bf1b27420a29cc21a5134671e571ed5578ea4d 100644
--- a/gui/slick/views/config_search.mako
+++ b/gui/slick/views/config_search.mako
@@ -124,6 +124,19 @@
                 </label>
             </div>

+            <div class="field-pair">
+                <label>
+                    <span class="component-title">Ignore language names in subbed results</span>
+                    <span class="component-desc">
+                        <input type="text" name="ignored_subs_list" value="${sickbeard.IGNORED_SUBS_LIST}" class="form-control input-sm input350" />
+                        <div class="clear-left">Ignore subbed releases based on language names <br />
+                        Example: "dk" will ignore words: dksub, dksubs, dksubbed, dksubed <br />
+                        Separate languages with a comma, e.g. "lang1,lang2,lang3"
+                        </div>
+                    </span>
+                </label>
+            </div>
+
             <div class="field-pair">
                 <label for="allow_high_priority">
                     <span class="component-title">Allow high priority</span>
diff --git a/gui/slick/views/inc_qualityChooser.mako b/gui/slick/views/inc_qualityChooser.mako
index 6faf7614bf421a2abad41ca39e618a525b804a1a..c2632960cbd0ba35b9451a1ff02c75084dbf758d 100644
--- a/gui/slick/views/inc_qualityChooser.mako
+++ b/gui/slick/views/inc_qualityChooser.mako
@@ -24,13 +24,11 @@ bestQualities = qualities[1]
 </select>

 <div id="customQualityWrapper">
-    <div id="customQuality">
-        <div class="component-group-desc">
-            <p><b>Preferred</b> qualities will replace an <b>Allowed</b> quality if found, initially or in the future, even if it is a lower quality</p>
-        </div>
+    <div id="customQuality" style="padding-left: 0px;">
+        <p><b><u>Preferred</u></b> qualities will replace those in <b><u>Allowed</u></b>, even if they are lower.</p>

     <div style="padding-right: 40px; text-align: left; float: left;">
-        <h5>Allowed</h4>
+        <h5>Allowed</h5>
         <% anyQualityList = filter(lambda x: x > Quality.NONE, Quality.qualityStrings) %>
         <select id="anyQualities" name="anyQualities" multiple="multiple" size="${len(anyQualityList)}" class="form-control form-control-inline input-sm">
         % for curQuality in sorted(anyQualityList):
@@ -40,7 +38,7 @@ bestQualities = qualities[1]
     </div>

     <div style="text-align: left; float: left;">
-        <h5>Preferred</h4>
+        <h5>Preferred</h5>
         <% bestQualityList = filter(lambda x: x >= Quality.SDTV and x < Quality.UNKNOWN, Quality.qualityStrings) %>
         <select id="bestQualities" name="bestQualities" multiple="multiple" size="${len(bestQualityList)}" class="form-control form-control-inline input-sm">
         % for curQuality in sorted(bestQualityList):
diff --git a/gui/slick/views/status.mako b/gui/slick/views/status.mako
index 0dbcef873b13540a4918dd24780a2cb8f2eda428..9162b62ae9314a18a10554983c4dca2e7169a996 100644
--- a/gui/slick/views/status.mako
+++ b/gui/slick/views/status.mako
@@ -203,9 +203,9 @@
             <td>TV Download Directory</td>
             <td>${sickbeard.TV_DOWNLOAD_DIR}</td>
             % if tvdirFree is not False:
-            <td>${tvdirFree} MB</td>
+            <td align="center">${tvdirFree}</td>
             % else:
-            <td><i>Missing</i></td>
+            <td align="center"><i>Missing</i></td>
             % endif
         </tr>
     % endif
@@ -214,9 +214,9 @@
         % for cur_dir in rootDir:
         <td>${cur_dir}</td>
         % if rootDir[cur_dir] is not False:
-        <td>${rootDir[cur_dir]} MB</td>
+        <td align="center">${rootDir[cur_dir]}</td>
         % else:
-        <td><i>Missing</i></td>
+        <td align="center"><i>Missing</i></td>
         % endif
     </tr>
     % endfor
diff --git a/sickbeard/__init__.py b/sickbeard/__init__.py
index e5070d571b5aac1afdce6e1610710b301a59abf8..790bf8c97fc631918c08c96bad54b4bfabca096d 100644
--- a/sickbeard/__init__.py
+++ b/sickbeard/__init__.py
@@ -270,6 +270,7 @@ ADD_SHOWS_WO_DIR = False
 CREATE_MISSING_SHOW_DIRS = False
 RENAME_EPISODES = False
 AIRDATE_EPISODES = False
+FILE_TIMESTAMP_TIMEZONE = None
 PROCESS_AUTOMATICALLY = False
 NO_DELETE = False
 KEEP_PROCESSED_DIR = False
@@ -536,6 +537,7 @@ EXTRA_SCRIPTS = []

 IGNORE_WORDS = "german,french,core2hd,dutch,swedish,reenc,MrLss"
 REQUIRE_WORDS = ""
+IGNORED_SUBS_LIST = "dk,fin,heb,kor,nor,nordic,pl,swe"
 SYNC_FILES = "!sync,lftp-pget-status,part,bts,!qb"

 CALENDAR_UNPROTECTED = False
@@ -590,7 +592,7 @@ def initialize(consoleLogging=True):
             KEEP_PROCESSED_DIR, PROCESS_METHOD, DELRARCONTENTS, TV_DOWNLOAD_DIR, MIN_DAILYSEARCH_FREQUENCY, DEFAULT_UPDATE_FREQUENCY, DEFAULT_SHOWUPDATE_HOUR, MIN_UPDATE_FREQUENCY, UPDATE_FREQUENCY, \
             showQueueScheduler, searchQueueScheduler, ROOT_DIRS, CACHE_DIR, ACTUAL_CACHE_DIR, TIMEZONE_DISPLAY, \
             NAMING_PATTERN, NAMING_MULTI_EP, NAMING_ANIME_MULTI_EP, NAMING_FORCE_FOLDERS, NAMING_ABD_PATTERN, NAMING_CUSTOM_ABD, NAMING_SPORTS_PATTERN, NAMING_CUSTOM_SPORTS, NAMING_ANIME_PATTERN, NAMING_CUSTOM_ANIME, NAMING_STRIP_YEAR, \
-            RENAME_EPISODES, AIRDATE_EPISODES, properFinderScheduler, PROVIDER_ORDER, autoPostProcesserScheduler, \
+            RENAME_EPISODES, AIRDATE_EPISODES, FILE_TIMESTAMP_TIMEZONE, properFinderScheduler, PROVIDER_ORDER, autoPostProcesserScheduler, \
             WOMBLE, BINSEARCH, OMGWTFNZBS, OMGWTFNZBS_USERNAME, OMGWTFNZBS_APIKEY, providerList, newznabProviderList, torrentRssProviderList, \
             EXTRA_SCRIPTS, USE_TWITTER, TWITTER_USERNAME, TWITTER_PASSWORD, TWITTER_PREFIX, DAILYSEARCH_FREQUENCY, TWITTER_DMTO, TWITTER_USEDM, \
             USE_BOXCAR, BOXCAR_USERNAME, BOXCAR_PASSWORD, BOXCAR_NOTIFY_ONDOWNLOAD, BOXCAR_NOTIFY_ONSUBTITLEDOWNLOAD, BOXCAR_NOTIFY_ONSNATCH, \
@@ -603,7 +605,7 @@ def initialize(consoleLogging=True):
             NEWZBIN, NEWZBIN_USERNAME, NEWZBIN_PASSWORD, GIT_PATH, MOVE_ASSOCIATED_FILES, SYNC_FILES, POSTPONE_IF_SYNC_FILES, dailySearchScheduler, NFO_RENAME, \
             GUI_NAME, HOME_LAYOUT, HISTORY_LAYOUT, DISPLAY_SHOW_SPECIALS, COMING_EPS_LAYOUT, COMING_EPS_SORT, COMING_EPS_DISPLAY_PAUSED, COMING_EPS_MISSED_RANGE, DISPLAY_FILESIZE, FUZZY_DATING, TRIM_ZERO, DATE_PRESET, TIME_PRESET, TIME_PRESET_W_SECONDS, THEME_NAME, FILTER_ROW, \
             POSTER_SORTBY, POSTER_SORTDIR, HISTORY_LIMIT, \
-            METADATA_WDTV, METADATA_TIVO, METADATA_MEDE8ER, IGNORE_WORDS, REQUIRE_WORDS, CALENDAR_UNPROTECTED, NO_RESTART, CREATE_MISSING_SHOW_DIRS, \
+            METADATA_WDTV, METADATA_TIVO, METADATA_MEDE8ER, IGNORE_WORDS, IGNORED_SUBS_LIST, REQUIRE_WORDS, CALENDAR_UNPROTECTED, NO_RESTART, CREATE_MISSING_SHOW_DIRS, \
             ADD_SHOWS_WO_DIR, USE_SUBTITLES, SUBTITLES_LANGUAGES, SUBTITLES_DIR, SUBTITLES_SERVICES_LIST, SUBTITLES_SERVICES_ENABLED, SUBTITLES_HISTORY, SUBTITLES_FINDER_FREQUENCY, SUBTITLES_MULTI, EMBEDDED_SUBTITLES_ALL, SUBTITLES_EXTRA_SCRIPTS, subtitlesFinderScheduler, \
             USE_FAILED_DOWNLOADS, DELETE_FAILED, ANON_REDIRECT, LOCALHOST_IP, TMDB_API_KEY, DEBUG, DEFAULT_PAGE, PROXY_SETTING, PROXY_INDEXERS, \
             AUTOPOSTPROCESSER_FREQUENCY, SHOWUPDATE_HOUR, DEFAULT_AUTOPOSTPROCESSER_FREQUENCY, MIN_AUTOPOSTPROCESSER_FREQUENCY, \
@@ -903,6 +905,7 @@ def initialize(consoleLogging=True):
         UNPACK = bool(check_setting_int(CFG, 'General', 'unpack', 0))
         RENAME_EPISODES = bool(check_setting_int(CFG, 'General', 'rename_episodes', 1))
         AIRDATE_EPISODES = bool(check_setting_int(CFG, 'General', 'airdate_episodes', 0))
+        FILE_TIMESTAMP_TIMEZONE = check_setting_str(CFG, 'General', 'file_timestamp_timezone', 'network')
         KEEP_PROCESSED_DIR = bool(check_setting_int(CFG, 'General', 'keep_processed_dir', 1))
         PROCESS_METHOD = check_setting_str(CFG, 'General', 'process_method', 'copy' if KEEP_PROCESSED_DIR else 'move')
         DELRARCONTENTS = bool(check_setting_int(CFG, 'General', 'del_rar_contents', 0))
@@ -1143,6 +1146,7 @@ def initialize(consoleLogging=True):

         IGNORE_WORDS = check_setting_str(CFG, 'General', 'ignore_words', IGNORE_WORDS)
         REQUIRE_WORDS = check_setting_str(CFG, 'General', 'require_words', REQUIRE_WORDS)
+        IGNORED_SUBS_LIST = check_setting_str(CFG, 'General', 'ignored_subs_list', IGNORED_SUBS_LIST)

         CALENDAR_UNPROTECTED = bool(check_setting_int(CFG, 'General', 'calendar_unprotected', 0))

@@ -1744,6 +1748,7 @@ def save_config():
     new_config['General']['unpack'] = int(UNPACK)
     new_config['General']['rename_episodes'] = int(RENAME_EPISODES)
     new_config['General']['airdate_episodes'] = int(AIRDATE_EPISODES)
+    new_config['General']['file_timestamp_timezone'] = FILE_TIMESTAMP_TIMEZONE
     new_config['General']['create_missing_show_dirs'] = int(CREATE_MISSING_SHOW_DIRS)
     new_config['General']['add_shows_wo_dir'] = int(ADD_SHOWS_WO_DIR)

@@ -1751,6 +1756,7 @@ def save_config():
     new_config['General']['git_path'] = GIT_PATH
     new_config['General']['ignore_words'] = IGNORE_WORDS
     new_config['General']['require_words'] = REQUIRE_WORDS
+    new_config['General']['ignored_subs_list'] = IGNORED_SUBS_LIST
     new_config['General']['calendar_unprotected'] = int(CALENDAR_UNPROTECTED)
     new_config['General']['no_restart'] = int(NO_RESTART)
     new_config['General']['developer'] = int(DEVELOPER)
diff --git a/sickbeard/helpers.py b/sickbeard/helpers.py
index 238b1cd1f42231a7f85c2b50a408fb8a1f365eb3..adb65d150d3ef6f78715fb0b9c50f8ef84adde70 100644
--- a/sickbeard/helpers.py
+++ b/sickbeard/helpers.py
@@ -457,26 +457,6 @@ def searchIndexerForShowID(regShowName, indexer=None, indexer_id=None, ui=None):

     return (None, None, None)

-
-def sizeof_fmt(num):
-    """
-    >>> sizeof_fmt(2)
-    '2.0 bytes'
-    >>> sizeof_fmt(1024)
-    '1.0 KB'
-    >>> sizeof_fmt(2048)
-    '2.0 KB'
-    >>> sizeof_fmt(2**20)
-    '1.0 MB'
-    >>> sizeof_fmt(1234567)
-    '1.2 MB'
-    """
-    for x in ['bytes', 'KB', 'MB', 'GB', 'TB']:
-        if num < 1024.0:
-            return "%3.1f %s" % (num, x)
-        num /= 1024.0
-
-
 def listMediaFiles(path):
     """
     Get a list of files possibly containing media in a path
@@ -1758,24 +1738,12 @@ def generateApiKey():

 def pretty_filesize(file_bytes):
     """Return humanly formatted sizes from bytes"""
-
-    file_bytes = float(file_bytes)
-    if file_bytes >= 1099511627776:
-        terabytes = file_bytes / 1099511627776
-        size = '%.2f TB' % terabytes
-    elif file_bytes >= 1073741824:
-        gigabytes = file_bytes / 1073741824
-        size = '%.2f GB' % gigabytes
-    elif file_bytes >= 1048576:
-        megabytes = file_bytes / 1048576
-        size = '%.2f MB' % megabytes
-    elif file_bytes >= 1024:
-        kilobytes = file_bytes / 1024
-        size = '%.2f KB' % kilobytes
-    else:
-        size = '%.2f b' % file_bytes
-
-    return size
+    file_bytes = float(file_bytes)
+    for mod in ['B', 'KB', 'MB', 'GB', 'TB', 'PB']:
+        if file_bytes < 1024.00:
+            return "%3.2f %s" % (file_bytes, mod)
+        file_bytes /= 1024.00
+    return "%3.2f EB" % file_bytes

 if __name__ == '__main__':
     import doctest
@@ -1851,7 +1819,7 @@ def verify_freespace(src, dest, oldfile=None):
     if diskfree > neededspace:
         return True
     else:
-        logger.log("Not enough free space: Needed: %s bytes ( %s ), found: %s bytes ( %s )" % ( neededspace, pretty_filesize(neededspace), diskfree, pretty_filesize(diskfree) ) , 
+        logger.log("Not enough free space: Needed: %s bytes ( %s ), found: %s bytes ( %s )" % ( neededspace, pretty_filesize(neededspace), diskfree, pretty_filesize(diskfree) ) ,
                    logger.WARNING)
         return False

@@ -1910,17 +1878,16 @@ def isFileLocked(checkfile, writeLockCheck=False):

 def getDiskSpaceUsage(diskPath=None):
     '''
-    returns the free space in MB for a given path or False if no path given
+    returns the free space in human-readable form for a given path, or False if no path given
     :param diskPath: the filesystem path being checked
     '''
-
     if diskPath and os.path.exists(diskPath):
         if platform.system() == 'Windows':
             free_bytes = ctypes.c_ulonglong(0)
             ctypes.windll.kernel32.GetDiskFreeSpaceExW(ctypes.c_wchar_p(diskPath), None, None, ctypes.pointer(free_bytes))
-            return free_bytes.value / 1024 / 1024
+            return pretty_filesize(free_bytes.value)
         else:
             st = os.statvfs(diskPath)
-            return st.f_bavail * st.f_frsize / 1024 / 1024
+            return pretty_filesize(st.f_bavail * st.f_frsize)
     else:
         return False
diff --git a/sickbeard/notifiers/kodi.py b/sickbeard/notifiers/kodi.py
index 124804af2530c3b48c6cd81c10df1423890a8c00..bec870a00f1e11c468eec7243ee9b3196de73fe4 100644
--- a/sickbeard/notifiers/kodi.py
+++ b/sickbeard/notifiers/kodi.py
@@ -130,7 +130,7 @@ class KODINotifier:

                 kodiapi = self._get_kodi_version(curHost, username, password)
                 if kodiapi:
-                    if (kodiapi <= 4):
+                    if kodiapi <= 4:
                         logger.log(u"Detected KODI version <= 11, using KODI HTTP API", logger.DEBUG)
                         command = {'command': 'ExecBuiltIn',
                                    'parameter': 'Notification(' + title.encode("utf-8") + ',' + message.encode(
@@ -143,7 +143,7 @@ class KODINotifier:
                         command = '{"jsonrpc":"2.0","method":"GUI.ShowNotification","params":{"title":"%s","message":"%s", "image": "%s"},"id":1}' % (
                             title.encode("utf-8"), message.encode("utf-8"), self.sr_logo_url)
                         notifyResult = self._send_to_kodi_json(command, curHost, username, password)
-                        if notifyResult.get('result'):
+                        if notifyResult and notifyResult.get('result'):
                             result += curHost + ':' + notifyResult["result"].decode(sickbeard.SYS_ENCODING)
                 else:
                     if sickbeard.KODI_ALWAYS_ON or force:
@@ -172,7 +172,7 @@ class KODINotifier:

             kodiapi = self._get_kodi_version(host, sickbeard.KODI_USERNAME, sickbeard.KODI_PASSWORD)
             if kodiapi:
-                if (kodiapi <= 4):
+                if kodiapi <= 4:
                     # try to update for just the show, if it fails, do full update if enabled
                     if not self._update_library(host, showName) and sickbeard.KODI_UPDATE_FULL:
                         logger.log(u"Single show update failed, falling back to full update", logger.DEBUG)
@@ -222,7 +222,7 @@ class KODINotifier:
             return False

         for key in command:
-            if type(command[key]) == unicode:
+            if isinstance(command[key], unicode):
                 command[key] = command[key].encode('utf-8')

         enc_command = urllib.urlencode(command)
@@ -240,7 +240,12 @@ class KODINotifier:
             else:
                 logger.log(u"Contacting KODI via url: " + ss(url), logger.DEBUG)

-            response = urllib2.urlopen(req)
+            try:
+                response = urllib2.urlopen(req)
+            except (httplib.BadStatusLine, urllib2.URLError) as e:
+                logger.log(u"Couldn't contact KODI HTTP at %r : %r" % (url, ex(e)), logger.DEBUG)
+                return False
+
             result = response.read().decode(sickbeard.SYS_ENCODING)
             response.close()

@@ -248,8 +253,7 @@ class KODINotifier:
             return result

         except Exception as e:
-            logger.log(u"Warning: Couldn't contact KODI HTTP at " + ss(url) + " " + str(e),
-                       logger.WARNING)
+            logger.log(u"Couldn't contact KODI HTTP at %r : %r" % (url, ex(e)), logger.DEBUG)
             return False

     def _update_library(self, host=None, showName=None):
@@ -466,7 +470,7 @@ class KODINotifier:
         del shows

         # we didn't find the show (exact match), thus revert to just doing a full update if enabled
-        if (tvshowid == -1):
+        if tvshowid == -1:
             logger.log(u'Exact show name not matched in KODI TV show list', logger.DEBUG)
             return False

@@ -532,10 +536,10 @@ class KODINotifier:
         if sickbeard.KODI_NOTIFY_ONSUBTITLEDOWNLOAD:
             self._notify_kodi(ep_name + ": " + lang, common.notifyStrings[common.NOTIFY_SUBTITLE_DOWNLOAD])

-    def notify_git_update(self, new_version = "??"):
+    def notify_git_update(self, new_version="??"):
         if sickbeard.USE_KODI:
-            update_text=common.notifyStrings[common.NOTIFY_GIT_UPDATE_TEXT]
-            title=common.notifyStrings[common.NOTIFY_GIT_UPDATE]
+            update_text = common.notifyStrings[common.NOTIFY_GIT_UPDATE_TEXT]
+            title = common.notifyStrings[common.NOTIFY_GIT_UPDATE]
             self._notify_kodi(update_text + new_version, title)

     def test_notify(self, host, username, password):
diff --git a/sickbeard/postProcessor.py b/sickbeard/postProcessor.py
index 788be3ea883a58016fb23c6df7f3d3669651bb92..93358b248cb59a47d568df53982620e6beac0f92 100644
--- a/sickbeard/postProcessor.py
+++ b/sickbeard/postProcessor.py
@@ -810,18 +810,18 @@ class PostProcessor(object):
         old_ep_status, old_ep_quality = common.Quality.splitCompositeStatus(ep_obj.status)

-        # if SB downloaded this on purpose we likely have a priority download
+        # if SR downloaded this on purpose we likely have a priority download
         if self.in_history or ep_obj.status in common.Quality.SNATCHED + common.Quality.SNATCHED_PROPER + common.Quality.SNATCHED_BEST:
             # if the episode is still in a snatched status, then we can assume we want this
             if ep_obj.status in common.Quality.SNATCHED + common.Quality.SNATCHED_PROPER + common.Quality.SNATCHED_BEST:
-                self._log(u"SB snatched this episode and it is not processed before", logger.DEBUG)
+                self._log(u"SR snatched this episode and it has not been processed before", logger.DEBUG)
                 return True

            # if it's not snatched, we only want it if the new quality is higher or if it's a proper of equal or higher quality
             if new_ep_quality > old_ep_quality and new_ep_quality != common.Quality.UNKNOWN:
-                self._log(u"SB snatched this episode and it is a higher quality so I'm marking it as priority", logger.DEBUG)
+                self._log(u"SR snatched this episode and it is a higher quality so I'm marking it as priority", logger.DEBUG)
                 return True

             if self.is_proper and new_ep_quality >= old_ep_quality and new_ep_quality != common.Quality.UNKNOWN:
-                self._log(u"SB snatched this episode and it is a proper of equal or higher quality so I'm marking it as priority", logger.DEBUG)
+                self._log(u"SR snatched this episode and it is a proper of equal or higher quality so I'm marking it as priority", logger.DEBUG)
                 return True

         return False
diff --git a/sickbeard/providers/alpharatio.py b/sickbeard/providers/alpharatio.py
index 001d62404edefcf9e87ef4cf921ba48b40e2dc37..c6d6a8d0902fa87a8b475b1b0f8300e41dae9ece 100644
--- a/sickbeard/providers/alpharatio.py
+++ b/sickbeard/providers/alpharatio.py
@@ -32,7 +32,6 @@ class AlphaRatioProvider(generic.TorrentProvider):
         generic.TorrentProvider.__init__(self, "AlphaRatio")

         self.supportsBacklog = True
-        self.public = False

         self.username = None
         self.password = None
@@ -40,17 +39,19 @@ class AlphaRatioProvider(generic.TorrentProvider):
         self.minseed = None
         self.minleech = None

-        self.cache = AlphaRatioCache(self)
-
         self.urls = {'base_url': 'http://alpharatio.cc/',
                      'login': 'http://alpharatio.cc/login.php',
                      'detail': 'http://alpharatio.cc/torrents.php?torrentid=%s',
                      'search': 'http://alpharatio.cc/torrents.php?searchstr=%s%s',
                      'download': 'http://alpharatio.cc/%s'}

+        self.url = self.urls['base_url']
+
         self.catagories = "&filter_cat[1]=1&filter_cat[2]=1&filter_cat[3]=1&filter_cat[4]=1&filter_cat[5]=1"

-        self.url = self.urls['base_url']
+        self.proper_strings = ['PROPER', 'REPACK']
+
+        self.cache = AlphaRatioCache(self)

     def isEnabled(self):
         return self.enabled
diff --git a/sickbeard/providers/animenzb.py b/sickbeard/providers/animenzb.py
index 24873a19050197b8cff3c020725ade7aa2398743..5a384a6221dc28a50b377906db6eb97decc03692 100644
--- a/sickbeard/providers/animenzb.py
+++ b/sickbeard/providers/animenzb.py
@@ -19,14 +19,14 @@

 import urllib
 import datetime

-import generic

 from sickbeard import classes
 from sickbeard import show_name_helpers
 from sickbeard import logger
-from sickbeard.common import *
+
 from sickbeard import tvcache
+from sickbeard.providers import generic


 class animenzb(generic.NZBProvider):
diff --git a/sickbeard/providers/bitcannon.py b/sickbeard/providers/bitcannon.py
index 730bf615821a6ce8ecc725d47e4a5e50d921b15f..d5feddffda304be9afb574f84d974671406ea862 100644
--- a/sickbeard/providers/bitcannon.py
+++ b/sickbeard/providers/bitcannon.py
@@ -112,7 +112,7 @@ class BitCannonCache(tvcache.TVCache):

     def _getRSSData(self):
         return {'entries': []}
-        #search_params = {'RSS': ['']}
-        #return {'entries': self.provider._doSearch(search_params)}
+        #search_strings = {'RSS': ['']}
+        #return {'entries': self.provider._doSearch(search_strings)}

 provider = BitCannonProvider()
diff --git a/sickbeard/providers/bitsoup.py b/sickbeard/providers/bitsoup.py
index 4c0076e6c344cfea761bb3179c8c4f269b2f66b2..0867cacb08736d4492406983f831f6f5ea5e6a56 100644
--- a/sickbeard/providers/bitsoup.py
+++ b/sickbeard/providers/bitsoup.py
@@ -39,7 +39,7 @@ class BitSoupProvider(generic.TorrentProvider):
         self.url = self.urls['base_url']

         self.supportsBacklog = True
-        self.public = False
+
         self.username = None
         self.password = None
         self.ratio = None
diff --git a/sickbeard/providers/bluetigers.py b/sickbeard/providers/bluetigers.py
index 3e6b3236ad3cdb3edfae36baa5945d07e1174d28..98039041a35fa10528dda0e6abc06677c3c7b130 100644
--- a/sickbeard/providers/bluetigers.py
+++ b/sickbeard/providers/bluetigers.py
@@ -34,7 +34,7 @@ class BLUETIGERSProvider(generic.TorrentProvider):
         generic.TorrentProvider.__init__(self, "BLUETIGERS")

         self.supportsBacklog = True
-        self.public = False
+
         self.username = None
         self.password = None
         self.ratio = None
diff --git a/sickbeard/providers/btn.py b/sickbeard/providers/btn.py
index 4154239733b175a995e180e27bd36a313a55958f..6647c38167d1fa02e82996c864911393b72466e5 100644
--- a/sickbeard/providers/btn.py
+++ b/sickbeard/providers/btn.py
@@ -39,7 +39,7 @@ class BTNProvider(generic.TorrentProvider):
         generic.TorrentProvider.__init__(self, "BTN")

         self.supportsBacklog = True
-        self.public = False
+
         self.supportsAbsoluteNumbering = True

         self.api_key = None
diff --git a/sickbeard/providers/cpasbien.py b/sickbeard/providers/cpasbien.py
index 486f22836ff366497741cee3249bb36f0b470c85..6baa067b49edd9286fa23f0b8fdf6005b93c141e 100644
--- a/sickbeard/providers/cpasbien.py
+++ b/sickbeard/providers/cpasbien.py
@@ -34,9 +34,11 @@ class CpasbienProvider(generic.TorrentProvider):
         self.supportsBacklog = True
         self.public = True
         self.ratio = None
-        self.cache = CpasbienCache(self)
         self.url = "http://www.cpasbien.pw"

+        self.proper_strings = ['PROPER', 'REPACK']
+
+        self.cache = CpasbienCache(self)

     def isEnabled(self):
         return self.enabled
diff --git a/sickbeard/providers/fnt.py b/sickbeard/providers/fnt.py
index 228d8b5042ba29e69b123b7b912d85babd5dbd66..4d15496433a3115d5f089fb8592195a5b1b59d6f 100644
--- a/sickbeard/providers/fnt.py
+++ b/sickbeard/providers/fnt.py
@@ -32,7 +32,7 @@ class FNTProvider(generic.TorrentProvider):
         generic.TorrentProvider.__init__(self, "FNT")

         self.supportsBacklog = True
-        self.public = False
+
         self.username = None
         self.password = None
         self.ratio = None
diff --git a/sickbeard/providers/frenchtorrentdb.py b/sickbeard/providers/frenchtorrentdb.py
index 51e1eee00ff9ab1c0e2ec6876bfa1cd81b82663b..90e1e2e00adf6d3c568ee31421b27ac25594d0d7 100644
--- a/sickbeard/providers/frenchtorrentdb.py
+++ b/sickbeard/providers/frenchtorrentdb.py
@@ -31,7 +31,7 @@ class FrenchTorrentDBProvider(generic.TorrentProvider):
         generic.TorrentProvider.__init__(self, "FrenchTorrentDB")

         self.supportsBacklog = True
-        self.public = False
+
         self.urls = {
             'base_url': 'http://www.frenchtorrentdb.com',
diff --git a/sickbeard/providers/freshontv.py b/sickbeard/providers/freshontv.py
index 2e3bad81cf218c66eaff0454fcf86307a47b76c6..ec355f82fcbd6c08692cae77e99c7fd7016d9364 100644
--- a/sickbeard/providers/freshontv.py
+++ b/sickbeard/providers/freshontv.py
@@ -34,7 +34,7 @@ class FreshOnTVProvider(generic.TorrentProvider):
         generic.TorrentProvider.__init__(self, "FreshOnTV")

         self.supportsBacklog = True
-        self.public = False
+
         self._uid = None
         self._hash = None
diff --git a/sickbeard/providers/generic.py b/sickbeard/providers/generic.py
index 6d7c7bcec8fb19f1512ba1f68f64a5ba44a0cf56..dbca1a58e24496180e514648399685778bf29f8e 100644
--- a/sickbeard/providers/generic.py
+++ b/sickbeard/providers/generic.py
@@ -55,7 +55,8 @@ class GenericProvider:
         self.proxyGlypeProxySSLwarning = None
         self.urls = {}
         self.url = ''
-        self.public = True
+
+        self.public = False

         self.show = None
@@ -87,6 +88,8 @@ class GenericProvider:

         shuffle(self.btCacheURLS)

+        self.proper_strings = ['PROPER|REPACK']
+
     def getID(self):
         return GenericProvider.makeID(self.name)
@@ -294,7 +297,7 @@ class GenericProvider:

         url = item.get('link')
         if url:
-            url = url.replace('&amp;', '&')
+            url = url.replace('&amp;', '&').replace('%26tr%3D', '&tr=')

         return title, url
@@ -634,9 +637,7 @@ class TorrentProvider(GenericProvider):
         return [search_string]

     def _clean_title_from_provider(self, title):
-        if title:
-            title = u'' + title.replace(' ', '.')
-        return title
+        return (title or '').replace(' ', '.')

     def findPropers(self, search_date=datetime.datetime.today()):

@@ -647,22 +648,19 @@ class TorrentProvider(GenericProvider):
             'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' +
             ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' +
             ' WHERE e.airdate >= ' + str(search_date.toordinal()) +
-            ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' +
-            ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))'
+            ' AND e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED + Quality.SNATCHED]) + ')'
         )

-        if not sqlResults:
-            return []
-
-        for sqlshow in sqlResults:
+        for sqlshow in sqlResults or []:
             show = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"]))
             if show:
                 curEp = show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"]))
-                searchString = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK')
+                for term in self.proper_strings:
+                    searchString = self._get_episode_search_strings(curEp, add_string=term)

-                for item in self._doSearch(searchString[0]):
-                    title, url = self._get_title_and_url(item)
-                    results.append(classes.Proper(title, url, datetime.datetime.today(), show))
+                    for item in self._doSearch(searchString[0]):
+                        title, url = self._get_title_and_url(item)
+                        results.append(classes.Proper(title, url, datetime.datetime.today(), show))

         return results
diff --git a/sickbeard/providers/hdbits.py b/sickbeard/providers/hdbits.py
index f11c677182353447056e8995f17d0843da607346..a4eb0acf2eee061ed5e727de34c223f8f11dc826 100644
--- a/sickbeard/providers/hdbits.py
+++ b/sickbeard/providers/hdbits.py
@@ -34,9 +34,8 @@ class HDBitsProvider(generic.TorrentProvider):
         generic.TorrentProvider.__init__(self, "HDBits")

         self.supportsBacklog = True
-        self.public = False

-        self.enabled = False
+
         self.username = None
         self.passkey = None
         self.ratio = None
diff --git a/sickbeard/providers/hdtorrents.py b/sickbeard/providers/hdtorrents.py
index b23301df34d1e61c07f198332b1038fbea6b8c6f..619e96a51a48b75380bee8e593a528143b04cea2 100644
--- a/sickbeard/providers/hdtorrents.py
+++ b/sickbeard/providers/hdtorrents.py
@@ -18,23 +18,14 @@
 # along with SickRage. If not, see <http://www.gnu.org/licenses/>.

 import re
-import sickbeard
-import generic
 import urllib
+import traceback
-from sickbeard.common import Quality
-from sickbeard import logger
-from sickbeard import tvcache
-from sickbeard import db
-from sickbeard import classes
-from sickbeard import helpers
-from sickbeard import show_name_helpers
-from sickrage.helper.exceptions import AuthException

 import requests
 from bs4 import BeautifulSoup
-from unidecode import unidecode
-from sickbeard.helpers import sanitizeSceneName
-from datetime import datetime
-import traceback
+
+from sickbeard import logger
+from sickbeard import tvcache
+from sickbeard.providers import generic

 class HDTorrentsProvider(generic.TorrentProvider):
     def __init__(self):
@@ -42,7 +33,6 @@ class HDTorrentsProvider(generic.TorrentProvider):
         generic.TorrentProvider.__init__(self, "HDTorrents")

         self.supportsBacklog = True
-        self.public = False

         self.username = None
         self.password = None
@@ -54,14 +44,15 @@ class HDTorrentsProvider(generic.TorrentProvider):
                      'login': 'https://hd-torrents.org/login.php',
                      'search': 'https://hd-torrents.org/torrents.php?search=%s&active=1&options=0%s',
                      'rss': 'https://hd-torrents.org/torrents.php?search=&active=1&options=0%s',
-                     'home': 'https://hd-torrents.org/%s'
-                     }
+                     'home': 'https://hd-torrents.org/%s'}

         self.url = self.urls['base_url']

+        self.categories = "&category[]=59&category[]=60&category[]=30&category[]=38"
+        self.proper_strings = ['PROPER', 'REPACK']
+
         self.cache = HDTorrentsCache(self)

-        self.categories = "&category[]=59&category[]=60&category[]=30&category[]=38"

     def isEnabled(self):
         return self.enabled
@@ -82,7 +73,7 @@ class HDTorrentsProvider(generic.TorrentProvider):
                         'pwd': self.password,
                         'submit': 'Confirm'}

-        response = self.getURL(self.urls['login'], post_data=login_params, timeout=30) 
+        response = self.getURL(self.urls['login'], post_data=login_params, timeout=30)
         if not response:
             logger.log(u"Unable to connect to provider", logger.WARNING)
             return False
@@ -181,7 +172,7 @@ class HDTorrentsProvider(generic.TorrentProvider):

                             if not size:
                                 size = -1
-                        except:
+                        except Exception:
                             logger.log(u"Failed parsing provider. Traceback: %s" % traceback.format_exc(), logger.ERROR)

                         if not all([title, download_url]):
@@ -209,41 +200,6 @@ class HDTorrentsProvider(generic.TorrentProvider):

         return results

-    def findPropers(self, search_date=datetime.today()):
-
-        results = []
-
-        myDB = db.DBConnection()
-        sqlResults = myDB.select(
-            'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' +
-            ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' +
-            ' WHERE e.airdate >= ' + str(search_date.toordinal()) +
-            ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' +
-            ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))'
-        )
-
-        if not sqlResults:
-            return []
-
-        for sqlshow in sqlResults:
-            self.show = curshow = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"]))
-            if not self.show: continue
-            curEp = curshow.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"]))
-
-            proper_searchString = self._get_episode_search_strings(curEp, add_string='PROPER')
-
-            for item in self._doSearch(proper_searchString[0]):
-                title, url = self._get_title_and_url(item)
-                results.append(classes.Proper(title, url, datetime.today(), self.show))
-
-            repack_searchString = self._get_episode_search_strings(curEp, add_string='REPACK')
-
-            for item in self._doSearch(repack_searchString[0]):
-                title, url = self._get_title_and_url(item)
-                results.append(classes.Proper(title, url, datetime.today(), self.show))
-
-        return results
-
     def seedRatio(self):
         return self.ratio
diff --git a/sickbeard/providers/hounddawgs.py b/sickbeard/providers/hounddawgs.py
index 2f4706e05fab58a4b15e716c98e72422e7bf0497..6c9f18c1a7f64e3e7c69442506fe3d56f25a43c3 100644
--- a/sickbeard/providers/hounddawgs.py
+++ b/sickbeard/providers/hounddawgs.py
@@ -1,7 +1,7 @@
 # Author: Idan Gutman
 # URL: http://code.google.com/p/sickbeard/
 #
-# This file is part of SickRage. 
+# This file is part of SickRage.
 #
 # SickRage is free software: you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -18,19 +18,11 @@
 import re
 import traceback

-import datetime
-import sickbeard
-import generic

-from sickbeard.common import Quality
 from sickbeard import logger
 from sickbeard import tvcache
-from sickbeard import db
-from sickbeard import classes
-from sickbeard import helpers
-from sickbeard import show_name_helpers
 from sickbeard.bs4_parser import BS4Parser
-from sickbeard.helpers import sanitizeSceneName
+from sickbeard.providers import generic


 class HoundDawgsProvider(generic.TorrentProvider):
@@ -39,9 +31,7 @@ class HoundDawgsProvider(generic.TorrentProvider):
         generic.TorrentProvider.__init__(self, "HoundDawgs")

         self.supportsBacklog = True
-        self.public = False

-        self.enabled = False
         self.username = None
         self.password = None
         self.ratio = None
@@ -51,9 +41,8 @@ class HoundDawgsProvider(generic.TorrentProvider):
         self.cache = HoundDawgsCache(self)

         self.urls = {'base_url': 'https://hounddawgs.org/',
-                     'search': 'https://hounddawgs.org/torrents.php',
-                     'login': 'https://hounddawgs.org/login.php',
-                     }
+                     'search': 'https://hounddawgs.org/torrents.php',
+                     'login': 'https://hounddawgs.org/login.php'}

         self.url = self.urls['base_url']
@@ -81,11 +70,10 @@ class HoundDawgsProvider(generic.TorrentProvider):
         login_params = {'username': self.username,
                         'password': self.password,
                         'keeplogged': 'on',
-                        'login': 'Login',
-                        }
+                        'login': 'Login'}

         self.getURL(self.urls['base_url'], timeout=30)
-        response = self.getURL(self.urls['login'], post_data=login_params, timeout=30) 
+        response = self.getURL(self.urls['login'], post_data=login_params, timeout=30)
         if not response:
             logger.log(u"Unable to connect to provider", logger.WARNING)
             return False
@@ -118,7 +106,7 @@ class HoundDawgsProvider(generic.TorrentProvider):
                 data = self.getURL(self.urls['search'], params=self.search_params)

                 strTableStart = "<table class=\"torrent_table"
-                startTableIndex=data.find(strTableStart)
+                startTableIndex = data.find(strTableStart)
                 trimmedData = data[startTableIndex:]
                 if not trimmedData:
                     continue
@@ -144,7 +132,7 @@ class HoundDawgsProvider(generic.TorrentProvider):
                     allAs = (torrent[1]).find_all('a')

                     try:
-                        link = self.urls['base_url'] + allAs[2].attrs['href']
+                        #link = self.urls['base_url'] + allAs[2].attrs['href']
                         #url = result.find('td', attrs={'class': 'quickdownload'}).find('a')
                         title = allAs[2].string
                         #Trimming title so accepted by scene check(Feature has been rewuestet i forum)
@@ -159,7 +147,6 @@ class HoundDawgsProvider(generic.TorrentProvider):
                         title = title.replace("Subs.", "")

                         download_url = self.urls['base_url']+allAs[0].attrs['href']
-                        id = link.replace(self.urls['base_url']+'torrents.php?id=','')
                         #FIXME
                         size = -1
                         seeders = 1
@@ -193,35 +180,6 @@ class HoundDawgsProvider(generic.TorrentProvider):

         return results

-    def findPropers(self, search_date=datetime.datetime.today()):
-
-        results = []
-
-        myDB = db.DBConnection()
-        sqlResults = myDB.select(
-            'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' +
-            ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' +
-            ' WHERE e.airdate >= ' + str(search_date.toordinal()) +
-            ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' +
-            ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))'
-        )
-
-        if not sqlResults:
-            return []
-
-        for sqlshow in sqlResults:
-            self.show = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"]))
-            if self.show:
-                curEp = self.show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"]))
-
-                searchString = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK')
-
-                for item in self._doSearch(searchString[0]):
-                    title, url = self._get_title_and_url(item)
-                    results.append(classes.Proper(title, url, datetime.datetime.today(), self.show))
-
-        return results
-
     def seedRatio(self):
         return self.ratio
diff --git a/sickbeard/providers/iptorrents.py b/sickbeard/providers/iptorrents.py
index 50d86841ae5ea67ea6f93e1a411798ca26399a61..de57748862301eccce09c8f83ce7c2bd1031d998 100644
--- a/sickbeard/providers/iptorrents.py
+++ b/sickbeard/providers/iptorrents.py
@@ -17,7 +17,6 @@
 # along with SickRage. If not, see <http://www.gnu.org/licenses/>.

 import re
-import traceback
 from sickbeard.providers import generic
 from sickbeard import logger
 from sickbeard import tvcache
@@ -30,7 +29,7 @@ class IPTorrentsProvider(generic.TorrentProvider):
         generic.TorrentProvider.__init__(self, "IPTorrents")

         self.supportsBacklog = True
-        self.public = False
+
         self.username = None
         self.password = None
diff --git a/sickbeard/providers/kat.py b/sickbeard/providers/kat.py
index c614dcd775543263240117a8eca48ec811468355..5200f2c13a10feba1a015ac78afe164575efbf46 100644
--- a/sickbeard/providers/kat.py
+++ b/sickbeard/providers/kat.py
@@ -30,10 +30,6 @@ import HTMLParser
 import sickbeard
 from sickbeard import logger
 from sickbeard import tvcache
-from sickbeard import helpers
-from sickbeard import db
-from sickbeard import classes
-from sickbeard.common import Quality
 from sickbeard.common import USER_AGENT
 from sickbeard.providers import generic
 from xml.parsers.expat import ExpatError
@@ -46,7 +42,6 @@ class KATProvider(generic.TorrentProvider):

         self.supportsBacklog = True
         self.public = True
-        self.enabled = False
         self.confirmed = True
         self.ratio = None
         self.minseed = None
@@ -168,33 +163,6 @@ class KATProvider(generic.TorrentProvider):

         return results

-    def findPropers(self, search_date=datetime.datetime.today()-datetime.timedelta(days=1)):
-        results = []
-
-        myDB = db.DBConnection()
-        sqlResults = myDB.select(
-            'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate, s.indexer FROM tv_episodes AS e' +
-            ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' +
-            ' WHERE e.airdate >= ' + str(search_date.toordinal()) +
-            ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' +
-            ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))'
-        )
-
-        for sqlshow in sqlResults or []:
-            show = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"]))
-            if show:
-                curEp = show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"]))
-
-                searchStrings = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK')
-
-                for item in self._doSearch(searchStrings[0]):
-                    title, url = self._get_title_and_url(item)
-                    pubdate = item[6]
-
-                    results.append(classes.Proper(title, url, pubdate, show))
-
-        return results
-
     def seedRatio(self):
         return self.ratio
diff --git a/sickbeard/providers/libertalia.py b/sickbeard/providers/libertalia.py
index b89299cc4ce0b07dffd483a4d2df4e08958f520a..1db4c54ff848385d8974b52a61a5c6e087e8fe79 100644
--- a/sickbeard/providers/libertalia.py
+++ b/sickbeard/providers/libertalia.py
@@ -20,24 +20,14 @@
 # along with SickRage. If not, see <http://www.gnu.org/licenses/>.

 import re
-import datetime
-import sickbeard
-import generic
-
 import requests
 import cookielib
 import urllib

-from sickbeard.bs4_parser import BS4Parser
-from sickbeard.common import Quality
 from sickbeard import logger
-from sickbeard import show_name_helpers
-from sickbeard import db
-from sickbeard import helpers
-from sickbeard import classes
-from unidecode import unidecode
-from sickbeard.helpers import sanitizeSceneName
 from sickbeard import tvcache
+from sickbeard.providers import generic
+from sickbeard.bs4_parser import BS4Parser


 class LibertaliaProvider(generic.TorrentProvider):
@@ -46,7 +36,6 @@ class LibertaliaProvider(generic.TorrentProvider):
         generic.TorrentProvider.__init__(self, "Libertalia")

         self.supportsBacklog = True
-        self.public = False

         self.cj = cookielib.CookieJar()
@@ -72,10 +61,9 @@ class LibertaliaProvider(generic.TorrentProvider):
             return True

         login_params = {'username': self.username,
-                        'password': self.password
-                        }
+                        'password': self.password}

-        response = self.getURL(self.url + '/login.php', post_data=login_params, timeout=30) 
+        response = self.getURL(self.url + '/login.php', post_data=login_params, timeout=30)
         if not response:
             logger.log(u"Unable to connect to provider", logger.WARNING)
             return False
@@ -112,21 +100,21 @@ class LibertaliaProvider(generic.TorrentProvider):
                     continue

                 with BS4Parser(data, features=["html5lib", "permissive"]) as html:
-                    resultsTable = html.find("table", { "class" : "torrent_table" })
+                    resultsTable = html.find("table", {"class" : "torrent_table"})
                     if resultsTable:
-                        rows = resultsTable.findAll("tr" , {"class" : "torrent_row new "} )  # torrent_row new
+                        rows = resultsTable.findAll("tr", {"class" : "torrent_row new "})  # torrent_row new
                         for row in rows:

                             #bypass first row because title only
-                            columns = row.find('td', {"class" : "torrent_name"} )
-                            isvfclass = row.find('td', {"class" : "sprite-vf"} )
-                            isvostfrclass = row.find('td', {"class" : "sprite-vostfr"} )
-                            link = columns.find("a",  href=re.compile("torrents"))
+                            columns = row.find('td', {"class" : "torrent_name"})
+                            # isvfclass = row.find('td', {"class" : "sprite-vf"})
+                            #isvostfrclass = row.find('td', {"class" : "sprite-vostfr"})
+                            link = columns.find("a", href=re.compile("torrents"))
                             if link:
                                 title = link.text
-                                recherched=searchURL.replace(".","(.*)").replace(" ","(.*)").replace("'","(.*)")
-                                download_url = row.find("a",href=re.compile("torrent_pass"))['href']
+                                #recherched = searchURL.replace(".", "(.*)").replace(" ", "(.*)").replace("'", "(.*)")
+                                download_url = row.find("a", href=re.compile("torrent_pass"))['href']
                                 #FIXME
                                 size = -1
                                 seeders = 1
@@ -157,33 +145,6 @@ class LibertaliaProvider(generic.TorrentProvider):
     def seedRatio(self):
         return self.ratio

-    def findPropers(self, search_date=datetime.datetime.today()):
-
-        results = []
-
-        myDB = db.DBConnection()
-        sqlResults = myDB.select(
-            'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' +
-            ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' +
-            ' WHERE e.airdate >= ' + str(search_date.toordinal()) +
-            ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' +
-            ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))'
-        )
-
-        if not sqlResults:
-            return []
-
-        for sqlshow in sqlResults:
-            self.show = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"]))
-            if self.show:
-                curEp = self.show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"]))
-                search_params = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK')
-
-                for item in self._doSearch(search_params[0]):
-                    title, url = self._get_title_and_url(item)
-                    results.append(classes.Proper(title, url, datetime.datetime.today(), self.show))
-
-        return results
-

 class LibertaliaCache(tvcache.TVCache):
     def __init__(self, provider_obj):
diff --git a/sickbeard/providers/morethantv.py b/sickbeard/providers/morethantv.py
index 8f001e3b5e1806df5877fa6ccae2c6bd27247705..5b16ab01828c8abd6dbd12d9a98a10f3238a1dc8 100644
--- a/sickbeard/providers/morethantv.py
+++ b/sickbeard/providers/morethantv.py
@@ -1,7 +1,7 @@
 # Author: Seamus Wassman
 # URL: http://code.google.com/p/sickbeard/
 #
-# This file is part of SickRage. 
+# This file is part of SickRage.
 #
 # SickRage is free software: you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -21,22 +21,14 @@
 # are some mistakes or things I could have done better.

 import re
+import requests
 import traceback
-import datetime
-import sickbeard
-import generic
-from sickbeard.common import Quality
+
 from sickbeard import logger
 from sickbeard import tvcache
-from sickbeard import db
-from sickbeard import classes
-from sickbeard import helpers
-from sickbeard import show_name_helpers
-from sickrage.helper.exceptions import AuthException
-import requests
+from sickbeard.providers import generic
 from sickbeard.bs4_parser import BS4Parser
-from unidecode import unidecode
-from sickbeard.helpers import sanitizeSceneName
+from sickrage.helper.exceptions import AuthException


 class MoreThanTVProvider(generic.TorrentProvider):
@@ -46,9 +38,7 @@ class MoreThanTVProvider(generic.TorrentProvider):
         generic.TorrentProvider.__init__(self, "MoreThanTV")

         self.supportsBacklog = True
-        self.public = False

-        self.enabled = False
         self._uid = None
         self._hash = None
         self.username = None
@@ -58,19 +48,20 @@ class MoreThanTVProvider(generic.TorrentProvider):
         self.minleech = None
         self.freeleech = False

-        self.cache = MoreThanTVCache(self)
-
         self.urls = {'base_url': 'https://www.morethan.tv/',
-                     'login': 'https://www.morethan.tv/login.php',
-                     'detail': 'https://www.morethan.tv/torrents.php?id=%s',
-                     'search': 'https://www.morethan.tv/torrents.php?tags_type=1&order_by=time&order_way=desc&action=basic&searchsubmit=1&searchstr=%s',
-                     'download': 'https://www.morethan.tv/torrents.php?action=download&id=%s',
-                     }
+                     'login': 'https://www.morethan.tv/login.php',
+                     'detail': 'https://www.morethan.tv/torrents.php?id=%s',
+                     'search': 'https://www.morethan.tv/torrents.php?tags_type=1&order_by=time&order_way=desc&action=basic&searchsubmit=1&searchstr=%s',
+                     'download': 'https://www.morethan.tv/torrents.php?action=download&id=%s'}

         self.url = self.urls['base_url']

         self.cookies = None

+        self.proper_strings = ['PROPER', 'REPACK']
+
+        self.cache = MoreThanTVCache(self)
+

     def isEnabled(self):
         return self.enabled
@@ -90,10 +81,9 @@ class MoreThanTVProvider(generic.TorrentProvider):
         else:
             login_params = {'username': self.username,
                             'password': self.password,
-                            'login': 'submit'
-                            }
+                            'login': 'submit'}

-            response = self.getURL(self.urls['login'], post_data=login_params, timeout=30) 
+            response = self.getURL(self.urls['login'], post_data=login_params, timeout=30)
             if not response:
                 logger.log(u"Unable to connect to provider", logger.WARNING)
                 return False
@@ -122,7 +112,7 @@ class MoreThanTVProvider(generic.TorrentProvider):
                     logger.log(u"Search string: %s " % search_string, logger.DEBUG)

                 searchURL = self.urls['search'] % (search_string)
-                logger.log(u"Search URL: %s" % searchURL, logger.DEBUG) 
+ logger.log(u"Search URL: %s" % searchURL, logger.DEBUG) # returns top 15 results by default, expandable in user profile to 100 data = self.getURL(searchURL) @@ -142,17 +132,13 @@ class MoreThanTVProvider(generic.TorrentProvider): # skip colheader for result in torrent_rows[1:]: cells = result.findChildren('td') - - link = cells[1].find('a', attrs = {'title': 'Download'}) - - link_str = str(link['href']) + link = cells[1].find('a', attrs={'title': 'Download'}) #skip if torrent has been nuked due to poor quality if cells[1].find('img', alt='Nuked') != None: continue - torrent_id_long = link['href'].replace('torrents.php?action=download&id=', '') - id = torrent_id_long.split('&', 1)[0] + torrent_id_long = link['href'].replace('torrents.php?action=download&id=', '') try: if link.has_key('title'): @@ -197,35 +183,6 @@ class MoreThanTVProvider(generic.TorrentProvider): return results - def findPropers(self, search_date=datetime.datetime.today()): - - results = [] - - myDB = db.DBConnection() - sqlResults = myDB.select( - 'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' + - ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' + - ' WHERE e.airdate >= ' + str(search_date.toordinal()) + - ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' + - ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))' - ) - - if not sqlResults: - return [] - - for sqlshow in sqlResults: - self.show = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"])) - if self.show: - curEp = self.show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"])) - - searchString = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK') - - for item in self._doSearch(searchString[0]): - title, url = self._get_title_and_url(item) - results.append(classes.Proper(title, url, datetime.datetime.today(), self.show)) - - return results - def seedRatio(self): return self.ratio diff --git a/sickbeard/providers/nextgen.py b/sickbeard/providers/nextgen.py index 9d13cdf0db6e5407ecfcb2e47677bc38b2dcc613..d075fe07d205c351fa717d53f9e04b540f5d8d9a 100644 --- a/sickbeard/providers/nextgen.py +++ b/sickbeard/providers/nextgen.py @@ -19,15 +19,9 @@ import traceback import urllib import time -import datetime -import sickbeard from sickbeard import logger from sickbeard import tvcache -from sickbeard import db -from sickbeard import classes -from sickbeard import helpers -from sickbeard.common import Quality from sickbeard.providers import generic from sickbeard.bs4_parser import BS4Parser @@ -39,9 +33,8 @@ class NextGenProvider(generic.TorrentProvider): generic.TorrentProvider.__init__(self, "NextGen") self.supportsBacklog = True - self.public = False - self.enabled = False + self.username = None self.password = None self.ratio = None @@ -49,9 +42,8 @@ class NextGenProvider(generic.TorrentProvider): self.cache = NextGenCache(self) self.urls = {'base_url': 'https://nxtgn.info/', - 'search': 'https://nxtgn.info/browse.php?search=%s&cat=0&incldead=0&modes=%s', - 'login_page': 'https://nxtgn.info/login.php', - } + 'search': 'https://nxtgn.info/browse.php?search=%s&cat=0&incldead=0&modes=%s', + 'login_page': 'https://nxtgn.info/login.php'} self.url = self.urls['base_url'] @@ -163,10 +155,7 @@ class NextGenProvider(generic.TorrentProvider): try: title = result.find('div', attrs={'id': 'torrent-udgivelse2-users'}).a['title'] - - dl = result.find('div', attrs={'id': 'torrent-download'}).a - download_url = self.urls['base_url'] + 
(dl['href'], dl['id'])['id' in dl] - + download_url = self.urls['base_url'] + result.find('div', attrs={'id': 'torrent-download'}).a['href'] seeders = int(result.find('div', attrs={'id' : 'torrent-seeders'}).text) leechers = int(result.find('div', attrs={'id' : 'torrent-leechers'}).text) size = self._convertSize(result.find('div', attrs={'id' : 'torrent-size'}).text) @@ -215,34 +204,6 @@ class NextGenProvider(generic.TorrentProvider): size = size * 1024**4 return int(size) - def findPropers(self, search_date=datetime.datetime.today()): - - results = [] - - myDB = db.DBConnection() - sqlResults = myDB.select( - 'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' + - ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' + - ' WHERE e.airdate >= ' + str(search_date.toordinal()) + - ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' + - ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))' - ) - - if not sqlResults: - return [] - - for sqlshow in sqlResults: - self.show = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"])) - if self.show: - curEp = self.show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"])) - searchString = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK') - - for item in self._doSearch(searchString[0]): - title, url = self._get_title_and_url(item) - results.append(classes.Proper(title, url, datetime.datetime.today(), self.show)) - - return results - def seedRatio(self): return self.ratio diff --git a/sickbeard/providers/omgwtfnzbs.py b/sickbeard/providers/omgwtfnzbs.py index 8dddfab46b589bae299a9bbb56061c6cf7985b99..90f0b362b57f6179bba7086df08a3114b874146d 100644 --- a/sickbeard/providers/omgwtfnzbs.py +++ b/sickbeard/providers/omgwtfnzbs.py @@ -17,22 +17,20 @@ # along with SickRage. If not, see <http://www.gnu.org/licenses/>. 
 import urllib
+from datetime import datetime

 import sickbeard
-import generic
-
 from sickbeard import tvcache
 from sickbeard import classes
 from sickbeard import logger
 from sickbeard import show_name_helpers
-from datetime import datetime
-from sickrage.helper.exceptions import AuthException
+from sickbeard.providers import generic


 class OmgwtfnzbsProvider(generic.NZBProvider):
     def __init__(self):
         generic.NZBProvider.__init__(self, "omgwtfnzbs")
-        self.enabled = False
+
         self.username = None
         self.api_key = None
         self.cache = OmgwtfnzbsCache(self)
@@ -41,7 +39,7 @@ class OmgwtfnzbsProvider(generic.NZBProvider):
         self.url = self.urls['base_url']

         self.supportsBacklog = True
-        self.public = False
+

     def isEnabled(self):
         return self.enabled
@@ -141,7 +139,7 @@ class OmgwtfnzbsProvider(generic.NZBProvider):
             title, url = self._get_title_and_url(item)
             try:
                 result_date = datetime.fromtimestamp(int(item['usenetage']))
-            except:
+            except Exception:
                 result_date = None

             if result_date:
diff --git a/sickbeard/providers/rarbg.py b/sickbeard/providers/rarbg.py
index c6581dc38856f9b0754b2f287ea0e5350dc3209a..76db48395987edf307393e8e9fea95c2f43bb20d 100644
--- a/sickbeard/providers/rarbg.py
+++ b/sickbeard/providers/rarbg.py
@@ -19,22 +19,15 @@

 import traceback
 import re
-import generic
 import datetime
 import json
 import time
-
-import sickbeard
-from sickbeard.common import Quality, USER_AGENT

 from sickbeard import logger
 from sickbeard import tvcache
-from sickbeard import show_name_helpers
-from sickbeard import db
-from sickbeard import helpers
-from sickbeard import classes
+from sickbeard.providers import generic
+from sickbeard.common import USER_AGENT
 from sickbeard.indexers.indexer_config import INDEXER_TVDB
-from sickrage.helper.exceptions import ex


 class GetOutOfLoop(Exception):
@@ -61,25 +54,25 @@ class RarbgProvider(generic.TorrentProvider):
                      'listing': u'http://torrentapi.org/pubapi_v2.php?mode=list&app_id=sickrage',
                      'search': u'http://torrentapi.org/pubapi_v2.php?mode=search&app_id=sickrage&search_string={search_string}',
                      'search_tvdb': u'http://torrentapi.org/pubapi_v2.php?mode=search&app_id=sickrage&search_tvdb={tvdb}&search_string={search_string}',
-                     'api_spec': u'https://rarbg.com/pubapi/apidocs.txt',
-                     }
+                     'api_spec': u'https://rarbg.com/pubapi/apidocs.txt'}

         self.url = self.urls['listing']

         self.urlOptions = {'categories': '&category={categories}',
-                        'seeders': '&min_seeders={min_seeders}',
-                        'leechers': '&min_leechers={min_leechers}',
-                        'sorting' : '&sort={sorting}',
-                        'limit': '&limit={limit}',
-                        'format': '&format={format}',
-                        'ranked': '&ranked={ranked}',
-                        'token': '&token={token}',
-                        }
+                           'seeders': '&min_seeders={min_seeders}',
+                           'leechers': '&min_leechers={min_leechers}',
+                           'sorting' : '&sort={sorting}',
+                           'limit': '&limit={limit}',
+                           'format': '&format={format}',
+                           'ranked': '&ranked={ranked}',
+                           'token': '&token={token}'}

         self.defaultOptions = self.urlOptions['categories'].format(categories='tv') + \
                               self.urlOptions['limit'].format(limit='100') + \
                               self.urlOptions['format'].format(format='json_extended')

+        self.proper_strings = ['{{PROPER|REPACK}}']
+
         self.next_request = datetime.datetime.now()

         self.headers.update({'User-Agent': USER_AGENT})
@@ -258,34 +251,6 @@ class RarbgProvider(generic.TorrentProvider):

         return results

-    def findPropers(self, search_date=datetime.datetime.today()):
-
-        results = []
-
-        myDB = db.DBConnection()
-        sqlResults = myDB.select(
-            'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' +
-            ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' +
-            ' WHERE e.airdate >= ' + str(search_date.toordinal()) +
-            ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' +
-            ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))'
-        )
-
-        if not sqlResults:
-            return []
-
-        for sqlshow in sqlResults:
-            self.show = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"]))
-            if self.show:
-                curEp = self.show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"]))
-                searchString = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK')
-
-                for item in self._doSearch(searchString[0]):
-                    title, url = self._get_title_and_url(item)
-                    results.append(classes.Proper(title, url, datetime.datetime.today(), self.show))
-
-        return results
-
     def seedRatio(self):
         return self.ratio
diff --git a/sickbeard/providers/rsstorrent.py b/sickbeard/providers/rsstorrent.py
index b5c45bc01b240527584c41d40802aaa28f75d310..436e91c3f056c0e9e3ae9394ba64aa0b6072756d 100644
--- a/sickbeard/providers/rsstorrent.py
+++ b/sickbeard/providers/rsstorrent.py
@@ -1,6 +1,6 @@
 # Author: Mr_Orange
 #
-# This file is part of SickRage. 
+# This file is part of SickRage.
 #
 # SickRage is free software: you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -37,11 +37,10 @@ class TorrentRssProvider(generic.TorrentProvider):
         generic.TorrentProvider.__init__(self, name)

         self.cache = TorrentRssCache(self)
-        self.urls = {'base_url': re.sub('\/$', '', url)}
+        self.urls = {'base_url': re.sub(r'\/$', '', url)}

         self.url = self.urls['base_url']

-        self.enabled = True
         self.ratio = None
         self.supportsBacklog = False
diff --git a/sickbeard/providers/scc.py b/sickbeard/providers/scc.py
index b6fbd97e0753ed0f1b6b9b8efcdc8d31cbaf5beb..fd865d4721aaae342695d3fa4b66ac2adf514938 100644
--- a/sickbeard/providers/scc.py
+++ b/sickbeard/providers/scc.py
@@ -18,23 +18,15 @@
 # along with SickRage. If not, see <http://www.gnu.org/licenses/>.

import re -import datetime import time +import urllib import sickbeard -import generic -import urllib -from sickbeard.common import Quality, cpu_presets +from sickbeard.common import cpu_presets from sickbeard import logger from sickbeard import tvcache -from sickbeard import db -from sickbeard import classes -from sickbeard import helpers -from sickbeard import show_name_helpers +from sickbeard.providers import generic from sickbeard.bs4_parser import BS4Parser -from unidecode import unidecode -from sickbeard.helpers import sanitizeSceneName - class SCCProvider(generic.TorrentProvider): @@ -43,7 +35,7 @@ class SCCProvider(generic.TorrentProvider): generic.TorrentProvider.__init__(self, "SceneAccess") self.supportsBacklog = True - self.public = False + self.username = None self.password = None @@ -54,14 +46,13 @@ class SCCProvider(generic.TorrentProvider): self.cache = SCCCache(self) self.urls = {'base_url': 'https://sceneaccess.eu', - 'login': 'https://sceneaccess.eu/login', - 'detail': 'https://www.sceneaccess.eu/details?id=%s', - 'search': 'https://sceneaccess.eu/browse?search=%s&method=1&%s', - 'nonscene': 'https://sceneaccess.eu/nonscene?search=%s&method=1&c44=44&c45=44', - 'foreign': 'https://sceneaccess.eu/foreign?search=%s&method=1&c34=34&c33=33', - 'archive': 'https://sceneaccess.eu/archive?search=%s&method=1&c26=26', - 'download': 'https://www.sceneaccess.eu/%s', - } + 'login': 'https://sceneaccess.eu/login', + 'detail': 'https://www.sceneaccess.eu/details?id=%s', + 'search': 'https://sceneaccess.eu/browse?search=%s&method=1&%s', + 'nonscene': 'https://sceneaccess.eu/nonscene?search=%s&method=1&c44=44&c45=44', + 'foreign': 'https://sceneaccess.eu/foreign?search=%s&method=1&c34=34&c33=33', + 'archive': 'https://sceneaccess.eu/archive?search=%s&method=1&c26=26', + 'download': 'https://www.sceneaccess.eu/%s'} self.url = self.urls['base_url'] @@ -74,24 +65,23 @@ class SCCProvider(generic.TorrentProvider): login_params = {'username': self.username, 'password': self.password, - 'submit': 'come on in', - } + 'submit': 'come on in'} - response = self.getURL(self.urls['login'], post_data=login_params, timeout=30) + response = self.getURL(self.urls['login'], post_data=login_params, timeout=30) if not response: logger.log(u"Unable to connect to provider", logger.WARNING) return False - if re.search('Username or password incorrect', response) \ - or re.search('<title>SceneAccess \| Login</title>', response): + if re.search(r'Username or password incorrect', response) \ + or re.search(r'<title>SceneAccess \| Login</title>', response): logger.log(u"Invalid username or password. Check your settings", logger.WARNING) return False return True def _isSection(self, section, text): - title = '<title>.+? \| %s</title>' % section + title = r'<title>.+? 
\| %s</title>' % section return re.search(title, text, re.IGNORECASE) def _doSearch(self, search_strings, search_mode='eponly', epcount=0, age=0, epObj=None): @@ -150,13 +140,12 @@ class SCCProvider(generic.TorrentProvider): url = all_urls[0] title = link.string - if re.search('\.\.\.', title): + if re.search(r'\.\.\.', title): data = self.getURL(self.url + "/" + link['href']) if data: with BS4Parser(data) as details_html: title = re.search('(?<=").+(?<!")', details_html.title.string).group(0) download_url = self.urls['download'] % url['href'] - id = int(link['href'].replace('details?id=', '')) seeders = int(result.find('td', attrs={'class': 'ttr_seeders'}).string) leechers = int(result.find('td', attrs={'class': 'ttr_leechers'}).string) #FIXME @@ -186,35 +175,6 @@ class SCCProvider(generic.TorrentProvider): return results - def findPropers(self, search_date=datetime.datetime.today()): - - results = [] - - myDB = db.DBConnection() - sqlResults = myDB.select( - 'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' + - ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' + - ' WHERE e.airdate >= ' + str(search_date.toordinal()) + - ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' + - ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))' - ) - - if not sqlResults: - return [] - - for sqlshow in sqlResults: - self.show = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"])) - if self.show: - curEp = self.show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"])) - - searchString = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK') - - for item in self._doSearch(searchString[0]): - title, url = self._get_title_and_url(item) - results.append(classes.Proper(title, url, datetime.datetime.today(), self.show)) - - return results - def seedRatio(self): return self.ratio @@ -228,7 +188,7 @@ class SCCCache(tvcache.TVCache): self.minTime = 20 def _getRSSData(self): - search_params = u'' - return {'entries': self.provider._doSearch(search_params)} + search_strings = {'RSS': ['']} + return {'entries': self.provider._doSearch(search_strings)} provider = SCCProvider() diff --git a/sickbeard/providers/scenetime.py b/sickbeard/providers/scenetime.py index 430e4fcf2865d9e347de25ca1eb2eb8e9ced5ad8..3d50a0daaa9698f5929f69b583df68c081f3fb76 100644 --- a/sickbeard/providers/scenetime.py +++ b/sickbeard/providers/scenetime.py @@ -1,7 +1,7 @@ # Author: Idan Gutman # URL: http://code.google.com/p/sickbeard/ # -# This file is part of SickRage. +# This file is part of SickRage. # # SickRage is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -17,21 +17,13 @@ # along with SickRage. If not, see <http://www.gnu.org/licenses/>. 
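The findPropers body deleted from SCCProvider above is the same SQL-plus-search routine removed from nearly every provider in this patch; presumably it now lives once in generic.TorrentProvider, driven by the proper_strings lists added to several providers elsewhere in this diff. A rough sketch of the shared shape (function and argument names hypothetical):

def find_propers(provider, recent_episodes):
    # For each recently DOWNLOADED/SNATCHED episode, re-search with the
    # provider's proper tags appended and collect (title, url) candidates.
    results = []
    for episode in recent_episodes:
        for proper_string in getattr(provider, 'proper_strings', ['PROPER', 'REPACK']):
            search_strings = provider._get_episode_search_strings(episode, add_string=proper_string)
            for item in provider._doSearch(search_strings[0]):
                results.append(provider._get_title_and_url(item))
    return results

The SCCCache._getRSSData change just above follows the same consolidation: the cache now feeds _doSearch the {'RSS': ['']} mapping the reworked signature expects instead of a bare unicode string.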
import re -import traceback -import datetime -import sickbeard -import generic import urllib -from sickbeard.common import Quality +import traceback + from sickbeard import logger from sickbeard import tvcache -from sickbeard import db -from sickbeard import classes -from sickbeard import helpers -from sickbeard import show_name_helpers +from sickbeard.providers import generic from sickbeard.bs4_parser import BS4Parser -from unidecode import unidecode -from sickbeard.helpers import sanitizeSceneName class SceneTimeProvider(generic.TorrentProvider): @@ -41,9 +33,8 @@ class SceneTimeProvider(generic.TorrentProvider): generic.TorrentProvider.__init__(self, "SceneTime") self.supportsBacklog = True - self.public = False - self.enabled = False + self.username = None self.password = None self.ratio = None @@ -53,11 +44,10 @@ class SceneTimeProvider(generic.TorrentProvider): self.cache = SceneTimeCache(self) self.urls = {'base_url': 'https://www.scenetime.com', - 'login': 'https://www.scenetime.com/takelogin.php', - 'detail': 'https://www.scenetime.com/details.php?id=%s', - 'search': 'https://www.scenetime.com/browse.php?search=%s%s', - 'download': 'https://www.scenetime.com/download.php/%s/%s', - } + 'login': 'https://www.scenetime.com/takelogin.php', + 'detail': 'https://www.scenetime.com/details.php?id=%s', + 'search': 'https://www.scenetime.com/browse.php?search=%s%s', + 'download': 'https://www.scenetime.com/download.php/%s/%s'} self.url = self.urls['base_url'] @@ -69,10 +59,9 @@ class SceneTimeProvider(generic.TorrentProvider): def _doLogin(self): login_params = {'username': self.username, - 'password': self.password - } + 'password': self.password} - response = self.getURL(self.urls['login'], post_data=login_params, timeout=30) + response = self.getURL(self.urls['login'], post_data=login_params, timeout=30) if not response: logger.log(u"Unable to connect to provider", logger.WARNING) return False @@ -99,7 +88,7 @@ class SceneTimeProvider(generic.TorrentProvider): logger.log(u"Search string: %s " % search_string, logger.DEBUG) searchURL = self.urls['search'] % (urllib.quote(search_string), self.categories) - logger.log(u"Search URL: %s" % searchURL, logger.DEBUG) + logger.log(u"Search URL: %s" % searchURL, logger.DEBUG) data = self.getURL(searchURL) if not data: @@ -107,7 +96,7 @@ class SceneTimeProvider(generic.TorrentProvider): try: with BS4Parser(data, features=["html5lib", "permissive"]) as html: - torrent_table = html.select("#torrenttable table"); + torrent_table = html.select("#torrenttable table") torrent_rows = torrent_table[0].select("tr") if torrent_table else [] #Continue only if one Release is found @@ -118,24 +107,21 @@ class SceneTimeProvider(generic.TorrentProvider): # Scenetime apparently uses different number of cells in #torrenttable based # on who you are. This works around that by extracting labels from the first # <tr> and using their index to find the correct download/seeders/leechers td. 
- labels = [ label.get_text() for label in torrent_rows[0].find_all('td') ] + labels = [label.get_text() for label in torrent_rows[0].find_all('td')] for result in torrent_rows[1:]: cells = result.find_all('td') - link = cells[labels.index('Name')].find('a'); + link = cells[labels.index('Name')].find('a') full_id = link['href'].replace('details.php?id=', '') torrent_id = full_id.split("&")[0] try: title = link.contents[0].get_text() - filename = "%s.torrent" % title.replace(" ", ".") - download_url = self.urls['download'] % (torrent_id, filename) - id = int(torrent_id) seeders = int(cells[labels.index('Seeders')].get_text()) leechers = int(cells[labels.index('Leechers')].get_text()) #FIXME @@ -169,35 +155,6 @@ class SceneTimeProvider(generic.TorrentProvider): return results - def findPropers(self, search_date=datetime.datetime.today()): - - results = [] - - myDB = db.DBConnection() - sqlResults = myDB.select( - 'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' + - ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' + - ' WHERE e.airdate >= ' + str(search_date.toordinal()) + - ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' + - ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))' - ) - - if not sqlResults: - return [] - - for sqlshow in sqlResults: - self.show = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"])) - if self.show: - curEp = self.show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"])) - - searchString = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK') - - for item in self._doSearch(searchString[0]): - title, url = self._get_title_and_url(item) - results.append(classes.Proper(title, url, datetime.datetime.today(), self.show)) - - return results - def seedRatio(self): return self.ratio diff --git a/sickbeard/providers/shazbat.py b/sickbeard/providers/shazbat.py index 19dfea5cde8c671286cff36df86cca6694d627ed..8fffb1dfdbc0c6e2f49fc2fe7d35ab1796f0dc0f 100644 --- a/sickbeard/providers/shazbat.py +++ b/sickbeard/providers/shazbat.py @@ -29,7 +29,7 @@ class ShazbatProvider(generic.TorrentProvider): generic.TorrentProvider.__init__(self, "Shazbat.tv") self.supportsBacklog = False - self.public = False + self.passkey = None self.ratio = None diff --git a/sickbeard/providers/speedcd.py b/sickbeard/providers/speedcd.py index 41321f6e8fd0a7db582172a0e341d7825d3d1ccb..ca93eba6c02f244a5de23224744eb3bcba95e5ce 100644 --- a/sickbeard/providers/speedcd.py +++ b/sickbeard/providers/speedcd.py @@ -17,18 +17,10 @@ # along with SickRage. If not, see <http://www.gnu.org/licenses/>. 
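The SceneTime hunk above works around per-account column layouts by resolving cell positions from the header row instead of hard-coding indexes. A self-contained illustration of the same lookup, with hypothetical markup (bs4 is already a SickRage dependency):

from bs4 import BeautifulSoup

html = ("<table>"
        "<tr><td>Type</td><td>Name</td><td>Seeders</td><td>Leechers</td></tr>"
        "<tr><td>TV</td><td><a href='details.php?id=42&amp;hit=1'>Show.S01E01</a></td>"
        "<td>12</td><td>3</td></tr>"
        "</table>")
rows = BeautifulSoup(html, 'html.parser').find_all('tr')
labels = [label.get_text() for label in rows[0].find_all('td')]
cells = rows[1].find_all('td')
link = cells[labels.index('Name')].find('a')
torrent_id = link['href'].replace('details.php?id=', '').split('&')[0]
seeders = int(cells[labels.index('Seeders')].get_text())
assert (torrent_id, seeders) == ('42', 12)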
import re -import datetime -import sickbeard -import generic -from sickbeard.common import Quality from sickbeard import logger from sickbeard import tvcache -from sickbeard import db -from sickbeard import classes -from sickbeard import helpers -from sickbeard import show_name_helpers -from sickbeard.helpers import sanitizeSceneName +from sickbeard.providers import generic class SpeedCDProvider(generic.TorrentProvider): @@ -38,7 +30,6 @@ class SpeedCDProvider(generic.TorrentProvider): generic.TorrentProvider.__init__(self, "Speedcd") self.supportsBacklog = True - self.public = False self.username = None self.password = None @@ -47,29 +38,29 @@ class SpeedCDProvider(generic.TorrentProvider): self.minseed = None self.minleech = None - self.cache = SpeedCDCache(self) - self.urls = {'base_url': 'http://speed.cd/', - 'login': 'http://speed.cd/take_login.php', - 'detail': 'http://speed.cd/t/%s', - 'search': 'http://speed.cd/V3/API/API.php', - 'download': 'http://speed.cd/download.php?torrent=%s', - } + 'login': 'http://speed.cd/take_login.php', + 'detail': 'http://speed.cd/t/%s', + 'search': 'http://speed.cd/V3/API/API.php', + 'download': 'http://speed.cd/download.php?torrent=%s'} self.url = self.urls['base_url'] self.categories = {'Season': {'c14': 1}, 'Episode': {'c2': 1, 'c49': 1}, 'RSS': {'c14': 1, 'c2': 1, 'c49': 1}} + self.proper_strings = ['PROPER', 'REPACK'] + + self.cache = SpeedCDCache(self) + def isEnabled(self): return self.enabled def _doLogin(self): login_params = {'username': self.username, - 'password': self.password - } + 'password': self.password} - response = self.getURL(self.urls['login'], post_data=login_params, timeout=30) + response = self.getURL(self.urls['login'], post_data=login_params, timeout=30) if not response: logger.log(u"Unable to connect to provider", logger.WARNING) return False @@ -106,7 +97,7 @@ class SpeedCDProvider(generic.TorrentProvider): try: torrents = parsedJSON.get('Fs', [])[0].get('Cn', {}).get('torrents', []) - except: + except Exception: continue for torrent in torrents: @@ -143,35 +134,6 @@ class SpeedCDProvider(generic.TorrentProvider): return results - def findPropers(self, search_date=datetime.datetime.today()): - - results = [] - - myDB = db.DBConnection() - sqlResults = myDB.select( - 'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' + - ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' + - ' WHERE e.airdate >= ' + str(search_date.toordinal()) + - ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' + - ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))' - ) - - if not sqlResults: - return [] - - for sqlshow in sqlResults: - self.show = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"])) - if self.show: - curEp = self.show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"])) - - searchString = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK') - - for item in self._doSearch(searchString[0]): - title, url = self._get_title_and_url(item) - results.append(classes.Proper(title, url, datetime.datetime.today(), self.show)) - - return results - def seedRatio(self): return self.ratio diff --git a/sickbeard/providers/strike.py b/sickbeard/providers/strike.py index c2db7b41296dcee6a68908634751704bdaaa4167..0b339fdc9726097b2a3b9a8d55209d43488854ea 100644 --- a/sickbeard/providers/strike.py +++ b/sickbeard/providers/strike.py @@ -1,7 +1,7 @@ # Author: matigonkas # URL: https://github.com/SiCKRAGETV/sickrage # -# 
This file is part of SickRage. +# This file is part of SickRage. # # SickRage is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -16,14 +16,9 @@ # You should have received a copy of the GNU General Public License # along with SickRage. If not, see <http://www.gnu.org/licenses/>. -import datetime -import generic - from sickbeard import logger from sickbeard import tvcache -from sickbeard import show_name_helpers -from sickbeard.config import naming_ep_type -from sickbeard.helpers import sanitizeSceneName +from sickbeard.providers import generic class STRIKEProvider(generic.TorrentProvider): @@ -53,7 +48,7 @@ class STRIKEProvider(generic.TorrentProvider): logger.log(u"Search string: " + search_string.strip(), logger.DEBUG) searchURL = self.url + "api/v2/torrents/search/?category=TV&phrase=" + search_string - logger.log(u"Search URL: %s" % searchURL, logger.DEBUG) + logger.log(u"Search URL: %s" % searchURL, logger.DEBUG) jdata = self.getURL(searchURL, json=True) if not jdata: logger.log("No data returned from provider", logger.DEBUG) diff --git a/sickbeard/providers/t411.py b/sickbeard/providers/t411.py index 86bbb7384be20fb788b59068ce1b416faabca5da..e1f10c646ad755677c724b9b957ed0f1d1da9820 100644 --- a/sickbeard/providers/t411.py +++ b/sickbeard/providers/t411.py @@ -17,22 +17,13 @@ # You should have received a copy of the GNU General Public License # along with Sick Beard. If not, see <http://www.gnu.org/licenses/>. -import traceback -import re -import datetime import time +import traceback from requests.auth import AuthBase -import sickbeard -import generic -from sickbeard.common import Quality from sickbeard import logger from sickbeard import tvcache -from sickbeard import show_name_helpers -from sickbeard import db -from sickbeard import helpers -from sickbeard import classes -from sickbeard.helpers import sanitizeSceneName +from sickbeard.providers import generic class T411Provider(generic.TorrentProvider): @@ -40,8 +31,7 @@ class T411Provider(generic.TorrentProvider): generic.TorrentProvider.__init__(self, "T411") self.supportsBacklog = True - self.public = False - self.enabled = False + self.username = None self.password = None self.ratio = None @@ -54,8 +44,7 @@ class T411Provider(generic.TorrentProvider): 'search': 'https://api.t411.in/torrents/search/%s?cid=%s&limit=100', 'rss': 'https://api.t411.in/torrents/top/today', 'login_page': 'https://api.t411.in/auth', - 'download': 'https://api.t411.in/torrents/download/%s', - } + 'download': 'https://api.t411.in/torrents/download/%s'} self.url = self.urls['base_url'] @@ -172,36 +161,6 @@ class T411Provider(generic.TorrentProvider): return results - def findPropers(self, search_date=datetime.datetime.today()): - - results = [] - - myDB = db.DBConnection() - sqlResults = myDB.select( - 'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' + - ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' + - ' WHERE e.airdate >= ' + str(search_date.toordinal()) + - ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' + - ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))' - ) - - if not sqlResults: - return [] - - for sqlshow in sqlResults: - self.show = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"])) - if self.show: - curEp = self.show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"])) - searchString = self._get_episode_search_strings(curEp, 
add_string='PROPER|REPACK') - - searchResults = self._doSearch(searchString[0]) - for item in searchResults: - title, url = self._get_title_and_url(item) - if title and url: - results.append(classes.Proper(title, url, datetime.datetime.today(), self.show)) - - return results - def seedRatio(self): return self.ratio diff --git a/sickbeard/providers/thepiratebay.py b/sickbeard/providers/thepiratebay.py index 07b1d7db652cd246e170f9ca16cb5894541617d3..81490d0a4abded86f7ad0ca2d1b0ddcd1ef35af5 100644 --- a/sickbeard/providers/thepiratebay.py +++ b/sickbeard/providers/thepiratebay.py @@ -19,19 +19,12 @@ from __future__ import with_statement import re -import datetime from urllib import urlencode -import sickbeard -from sickbeard.providers import generic -from sickbeard.common import Quality -from sickbeard.common import USER_AGENT -from sickbeard import db -from sickbeard import classes from sickbeard import logger from sickbeard import tvcache -from sickbeard import helpers -from sickbeard.show_name_helpers import allPossibleShowNames, sanitizeSceneName +from sickbeard.providers import generic +from sickbeard.common import USER_AGENT class ThePirateBayProvider(generic.TorrentProvider): @@ -147,30 +140,6 @@ class ThePirateBayProvider(generic.TorrentProvider): size = size * 1024**4 return size - def findPropers(self, search_date=datetime.datetime.today()-datetime.timedelta(days=1)): - - results = [] - - myDB = db.DBConnection() - sqlResults = myDB.select( - 'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' + - ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' + - ' WHERE e.airdate >= ' + str(search_date.toordinal()) + - ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' + - ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))' - ) - - for sqlshow in sqlResults or []: - show = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"])) - if show: - curEp = show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"])) - searchStrings = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK') - for item in self._doSearch(searchStrings[0]): - title, url = self._get_title_and_url(item) - results.append(classes.Proper(title, url, search_date, show)) - - return results - def seedRatio(self): return self.ratio diff --git a/sickbeard/providers/titansoftv.py b/sickbeard/providers/titansoftv.py index c9726309acae903057fecb3aed2f038fe98443dc..8e6d254c624b8f8eb1c178aaf643cc8587c13ad9 100644 --- a/sickbeard/providers/titansoftv.py +++ b/sickbeard/providers/titansoftv.py @@ -30,7 +30,7 @@ class TitansOfTVProvider(generic.TorrentProvider): def __init__(self): generic.TorrentProvider.__init__(self, 'TitansOfTV') self.supportsBacklog = True - self.public = False + self.supportsAbsoluteNumbering = True self.api_key = None self.ratio = None diff --git a/sickbeard/providers/tntvillage.py b/sickbeard/providers/tntvillage.py index c31ea2524d5e91d290faebe4cb881901c768f6ff..534678e35d8726c350b8444d2a078db0404ccbad 100644 --- a/sickbeard/providers/tntvillage.py +++ b/sickbeard/providers/tntvillage.py @@ -1,7 +1,7 @@ # Author: Giovanni Borri # Modified by gborri, https://github.com/gborri for TNTVillage # -# This file is part of SickRage. +# This file is part of SickRage. 
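Several hunks in this patch (speedcd, torrentday, transmitthenet, tvchaosuk) narrow bare except: clauses to except Exception:. The difference matters in long-running search threads: a bare except also traps SystemExit and KeyboardInterrupt, which can make shutdown hang. A minimal sketch of the pattern around the provider JSON, with a hypothetical empty response:

parsedJSON = {}                                   # hypothetical provider response
try:
    torrents = parsedJSON.get('Fs', [])[0].get('Cn', {}).get('torrents', [])
except Exception:                                 # IndexError caught; shutdown signals still propagate
    torrents = []
assert torrents == []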
# # SickRage is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -18,49 +18,41 @@ import re import traceback -import datetime -import sickbeard -import generic from sickbeard.common import Quality from sickbeard import logger from sickbeard import tvcache from sickbeard import db -from sickbeard import classes -from sickbeard import helpers -from sickbeard import show_name_helpers + +from sickbeard.providers import generic from sickbeard.bs4_parser import BS4Parser -from unidecode import unidecode -from sickbeard.helpers import sanitizeSceneName from sickbeard.name_parser.parser import NameParser, InvalidNameException, InvalidShowException from sickrage.helper.exceptions import AuthException -category_excluded = { - 'Sport' : 22, - 'Teatro' : 23, - 'Video Musicali' : 21, - 'Film' : 4, - 'Musica' : 2, - 'Students Releases' : 13, - 'E Books' : 3, - 'Linux' : 6, - 'Macintosh' : 9, - 'Windows Software' : 10, - 'Pc Game' : 11, - 'Playstation 2' : 12, - 'Wrestling' : 24, - 'Varie' : 25, - 'Xbox' : 26, - 'Immagini sfondi' : 27, - 'Altri Giochi' : 28, - 'Fumetteria' : 30, - 'Trash' : 31, - 'PlayStation 1' : 32, - 'PSP Portable' : 33, - 'A Book' : 34, - 'Podcast' : 35, - 'Edicola' : 36, - 'Mobile' : 37, - } +category_excluded = {'Sport' : 22, + 'Teatro' : 23, + 'Video Musicali' : 21, + 'Film' : 4, + 'Musica' : 2, + 'Students Releases' : 13, + 'E Books' : 3, + 'Linux' : 6, + 'Macintosh' : 9, + 'Windows Software' : 10, + 'Pc Game' : 11, + 'Playstation 2' : 12, + 'Wrestling' : 24, + 'Varie' : 25, + 'Xbox' : 26, + 'Immagini sfondi' : 27, + 'Altri Giochi' : 28, + 'Fumetteria' : 30, + 'Trash' : 31, + 'PlayStation 1' : 32, + 'PSP Portable' : 33, + 'A Book' : 34, + 'Podcast' : 35, + 'Edicola' : 36, + 'Mobile' : 37} class TNTVillageProvider(generic.TorrentProvider): def __init__(self): @@ -68,9 +60,7 @@ class TNTVillageProvider(generic.TorrentProvider): generic.TorrentProvider.__init__(self, "TNTVillage") self.supportsBacklog = True - self.public = False - self.enabled = False self._uid = None self._hash = None self.username = None @@ -83,8 +73,7 @@ class TNTVillageProvider(generic.TorrentProvider): self.minseed = None self.minleech = None - self.hdtext = [ - ' - Versione 720p', + self.hdtext = [' - Versione 720p', ' Versione 720p', ' V 720p', ' V 720', @@ -95,35 +84,34 @@ class TNTVillageProvider(generic.TorrentProvider): ' 720p HEVC', ' Ver 720', ' 720p HEVC', - ' 720p', - ] + ' 720p'] - self.category_dict = { - 'Serie TV' : 29, + self.category_dict = {'Serie TV' : 29, 'Cartoni' : 8, 'Anime' : 7, 'Programmi e Film TV' : 1, 'Documentari' : 14, - 'All' : 0, - } + 'All' : 0} self.urls = {'base_url' : 'http://forum.tntvillage.scambioetico.org', - 'login' : 'http://forum.tntvillage.scambioetico.org/index.php?act=Login&CODE=01', - 'detail' : 'http://forum.tntvillage.scambioetico.org/index.php?showtopic=%s', - 'search' : 'http://forum.tntvillage.scambioetico.org/?act=allreleases&%s', - 'search_page' : 'http://forum.tntvillage.scambioetico.org/?act=allreleases&st={0}&{1}', - 'download' : 'http://forum.tntvillage.scambioetico.org/index.php?act=Attach&type=post&id=%s', - } - - self.sub_string = ['sub', 'softsub'] + 'login' : 'http://forum.tntvillage.scambioetico.org/index.php?act=Login&CODE=01', + 'detail' : 'http://forum.tntvillage.scambioetico.org/index.php?showtopic=%s', + 'search' : 'http://forum.tntvillage.scambioetico.org/?act=allreleases&%s', + 'search_page' : 'http://forum.tntvillage.scambioetico.org/?act=allreleases&st={0}&{1}', 
+ 'download' : 'http://forum.tntvillage.scambioetico.org/index.php?act=Attach&type=post&id=%s'} self.url = self.urls['base_url'] - self.cache = TNTVillageCache(self) + self.cookies = None + + self.sub_string = ['sub', 'softsub'] + + self.proper_strings = ['PROPER', 'REPACK'] self.categories = "cat=29" - self.cookies = None + self.cache = TNTVillageCache(self) + def isEnabled(self): return self.enabled @@ -140,10 +128,9 @@ class TNTVillageProvider(generic.TorrentProvider): login_params = {'UserName': self.username, 'PassWord': self.password, 'CookieDate': 0, - 'submit': 'Connettiti al Forum', - } + 'submit': 'Connettiti al Forum'} - response = self.getURL(self.urls['login'], post_data=login_params, timeout=30) + response = self.getURL(self.urls['login'], post_data=login_params, timeout=30) if not response: logger.log(u"Unable to connect to provider", logger.WARNING) return False @@ -180,18 +167,18 @@ class TNTVillageProvider(generic.TorrentProvider): return quality_string - def _episodeQuality(self,torrent_rows): + def _episodeQuality(self, torrent_rows): """ Return The quality from the scene episode HTML row. """ - file_quality='' + file_quality = '' img_all = (torrent_rows.find_all('td'))[1].find_all('img') if len(img_all) > 0: for img_type in img_all: try: - file_quality = file_quality + " " + img_type['src'].replace("style_images/mkportal-636/","").replace(".gif","").replace(".png","") + file_quality = file_quality + " " + img_type['src'].replace("style_images/mkportal-636/", "").replace(".gif", "").replace(".png", "") except Exception: logger.log(u"Failed parsing quality. Traceback: %s" % traceback.format_exc(), logger.ERROR) @@ -202,7 +189,7 @@ class TNTVillageProvider(generic.TorrentProvider): checkName = lambda list, func: func([re.search(x, file_quality, re.I) for x in list]) dvdOptions = checkName(["dvd", "dvdrip", "dvdmux", "DVD9", "DVD5"], any) - bluRayOptions = checkName(["BD","BDmux", "BDrip", "BRrip", "Bluray"], any) + bluRayOptions = checkName(["BD", "BDmux", "BDrip", "BRrip", "Bluray"], any) sdOptions = checkName(["h264", "divx", "XviD", "tv", "TVrip", "SATRip", "DTTrip", "Mpeg2"], any) hdOptions = checkName(["720p"], any) fullHD = checkName(["1080p", "fullHD"], any) @@ -254,7 +241,7 @@ class TNTVillageProvider(generic.TorrentProvider): italian = True return italian - + def _is_english(self, torrent_rows): name = str(torrent_rows.find_all('td')[1].find('b').find('span')) @@ -274,14 +261,14 @@ class TNTVillageProvider(generic.TorrentProvider): myParser = NameParser(tryIndexers=True, trySceneExceptions=True) parse_result = myParser.parse(name) except InvalidNameException: - logger.log(u"Unable to parse the filename %s into a valid episode" % title, logger.DEBUG) + logger.log(u"Unable to parse the filename %s into a valid episode" % name, logger.DEBUG) return False except InvalidShowException: - logger.log(u"Unable to parse the filename %s into a valid show" % title, logger.DEBUG) + logger.log(u"Unable to parse the filename %s into a valid show" % name, logger.DEBUG) return False myDB = db.DBConnection() - sql_selection="select count(*) as count from tv_episodes where showid = ? and season = ?" + sql_selection = "select count(*) as count from tv_episodes where showid = ? and season = ?" 
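# Side note (hedged sketch): the count query above is parameterized, so the
# show id and season are bound as values rather than formatted into the SQL.
# A standalone equivalent against an in-memory sqlite3 database, with the
# table layout assumed from the query itself:
import sqlite3

_conn = sqlite3.connect(':memory:')
_conn.execute('CREATE TABLE tv_episodes (showid INT, season INT)')
_conn.executemany('INSERT INTO tv_episodes VALUES (?, ?)', [(1, 1), (1, 1)])
_count = _conn.execute('select count(*) as count from tv_episodes where showid = ? and season = ?',
                       [1, 1]).fetchone()[0]
assert _count == 2
# A release is treated as a full-season pack only when its parsed episode list
# covers every episode the database knows for that season.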
episodes = myDB.select(sql_selection, [parse_result.show.indexerid, parse_result.season_number]) if int(episodes[0]['count']) == len(parse_result.episode_numbers): return True @@ -303,28 +290,28 @@ class TNTVillageProvider(generic.TorrentProvider): if mode == 'RSS': self.page = 2 - last_page=0 - y=int(self.page) + last_page = 0 + y = int(self.page) if search_string == '': continue search_string = str(search_string).replace('.', ' ') - for x in range(0,y): - z=x*20 + for x in range(0, y): + z = x*20 if last_page: break if mode != 'RSS': - searchURL = (self.urls['search_page'] + '&filter={2}').format(z,self.categories,search_string) + searchURL = (self.urls['search_page'] + '&filter={2}').format(z, self.categories, search_string) else: - searchURL = self.urls['search_page'].format(z,self.categories) + searchURL = self.urls['search_page'].format(z, self.categories) if mode != 'RSS': logger.log(u"Search string: %s " % search_string, logger.DEBUG) - logger.log(u"Search URL: %s" % searchURL, logger.DEBUG) + logger.log(u"Search URL: %s" % searchURL, logger.DEBUG) data = self.getURL(searchURL) if not data: logger.log("No data returned from provider", logger.DEBUG) @@ -332,25 +319,24 @@ class TNTVillageProvider(generic.TorrentProvider): try: with BS4Parser(data, features=["html5lib", "permissive"]) as html: - torrent_table = html.find('table', attrs = {'class' : 'copyright'}) + torrent_table = html.find('table', attrs={'class' : 'copyright'}) torrent_rows = torrent_table.find_all('tr') if torrent_table else [] #Continue only if one Release is found - if len(torrent_rows)<3: + if len(torrent_rows) < 3: logger.log(u"Data returned from provider does not contain any torrents", logger.DEBUG) - last_page=1 + last_page = 1 continue if len(torrent_rows) < 42: - last_page=1 + last_page = 1 for result in torrent_table.find_all('tr')[2:]: try: link = result.find('td').find('a') title = link.string - id = ((result.find_all('td')[8].find('a'))['href'])[-8:] - download_url = self.urls['download'] % (id) + download_url = self.urls['download'] % result.find_all('td')[8].find('a')['href'][-8:] leechers = result.find_all('td')[3].find_all('td')[1].text leechers = int(leechers.strip('[]')) seeders = result.find_all('td')[3].find_all('td')[2].text @@ -363,7 +349,7 @@ class TNTVillageProvider(generic.TorrentProvider): filename_qt = self._reverseQuality(self._episodeQuality(result)) for text in self.hdtext: title1 = title - title = title.replace(text,filename_qt) + title = title.replace(text, filename_qt) if title != title1: break @@ -416,35 +402,6 @@ class TNTVillageProvider(generic.TorrentProvider): return results - def findPropers(self, search_date=datetime.datetime.today()): - - results = [] - - myDB = db.DBConnection() - sqlResults = myDB.select( - 'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' + - ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' + - ' WHERE e.airdate >= ' + str(search_date.toordinal()) + - ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' + - ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))' - ) - - if not sqlResults: - return [] - - for sqlshow in sqlResults: - self.show = curshow = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"])) - if not self.show: continue - curEp = curshow.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"])) - - searchString = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK') - - for item in self._doSearch(searchString[0]): - 
title, url = self._get_title_and_url(item) - results.append(classes.Proper(title, url, datetime.datetime.today(), self.show)) - - return results - def seedRatio(self): return self.ratio diff --git a/sickbeard/providers/tokyotoshokan.py b/sickbeard/providers/tokyotoshokan.py index 6adee66ba007e809774077d17f144ae0766554b5..d1d909528733182789b1585171ee954fa3d25a14 100644 --- a/sickbeard/providers/tokyotoshokan.py +++ b/sickbeard/providers/tokyotoshokan.py @@ -19,12 +19,9 @@ import urllib import traceback -import generic - -from sickbeard import show_name_helpers from sickbeard import logger -from sickbeard.common import Quality from sickbeard import tvcache +from sickbeard.providers import generic from sickbeard import show_name_helpers from sickbeard.bs4_parser import BS4Parser @@ -51,9 +48,6 @@ class TokyoToshokanProvider(generic.TorrentProvider): def seedRatio(self): return self.ratio - def findSearchResults(self, show, episodes, search_mode, manualSearch=False, downCurQuality=False): - return generic.TorrentProvider.findSearchResults(self, show, episodes, search_mode, manualSearch, downCurQuality) - def _get_season_search_strings(self, ep_obj): return [x.replace('.', ' ') for x in show_name_helpers.makeSceneSeasonSearchString(self.show, ep_obj)] diff --git a/sickbeard/providers/torrentbytes.py b/sickbeard/providers/torrentbytes.py index 98fa2c808ce1fd0e96a19bfa7016a2dedb2cb214..461009f78beee3017751d6cb898e6441682e757f 100644 --- a/sickbeard/providers/torrentbytes.py +++ b/sickbeard/providers/torrentbytes.py @@ -1,7 +1,7 @@ # Author: Idan Gutman # URL: http://code.google.com/p/sickbeard/ # -# This file is part of SickRage. +# This file is part of SickRage. # # SickRage is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -17,21 +17,13 @@ # along with SickRage. If not, see <http://www.gnu.org/licenses/>. 
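The TokyoToshokan findSearchResults override deleted above did nothing but forward every argument to the base class, which is exactly what inheritance already does, so removing it changes no behavior. A minimal demonstration:

class Base(object):
    def findSearchResults(self, *args, **kwargs):
        return 'base result'

class Provider(Base):
    pass                               # no forwarding override needed

assert Provider().findSearchResults() == 'base result'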
import re -import traceback -import datetime -import sickbeard -import generic import urllib -from sickbeard.common import Quality +import traceback + from sickbeard import logger from sickbeard import tvcache -from sickbeard import db -from sickbeard import classes -from sickbeard import helpers -from sickbeard import show_name_helpers +from sickbeard.providers import generic from sickbeard.bs4_parser import BS4Parser -from unidecode import unidecode -from sickbeard.helpers import sanitizeSceneName class TorrentBytesProvider(generic.TorrentProvider): @@ -41,28 +33,27 @@ class TorrentBytesProvider(generic.TorrentProvider): generic.TorrentProvider.__init__(self, "TorrentBytes") self.supportsBacklog = True - self.public = False - self.enabled = False self.username = None self.password = None self.ratio = None self.minseed = None self.minleech = None - self.cache = TorrentBytesCache(self) - self.urls = {'base_url': 'https://www.torrentbytes.net', - 'login': 'https://www.torrentbytes.net/takelogin.php', - 'detail': 'https://www.torrentbytes.net/details.php?id=%s', - 'search': 'https://www.torrentbytes.net/browse.php?search=%s%s', - 'download': 'https://www.torrentbytes.net/download.php?id=%s&name=%s', - } + 'login': 'https://www.torrentbytes.net/takelogin.php', + 'detail': 'https://www.torrentbytes.net/details.php?id=%s', + 'search': 'https://www.torrentbytes.net/browse.php?search=%s%s', + 'download': 'https://www.torrentbytes.net/download.php?id=%s&name=%s'} self.url = self.urls['base_url'] self.categories = "&c41=1&c33=1&c38=1&c32=1&c37=1" + self.proper_strings = ['PROPER', 'REPACK'] + + self.cache = TorrentBytesCache(self) + def isEnabled(self): return self.enabled @@ -70,10 +61,9 @@ class TorrentBytesProvider(generic.TorrentProvider): login_params = {'username': self.username, 'password': self.password, - 'login': 'Log in!' 
- } + 'login': 'Log in!'} - response = self.getURL(self.urls['login'], post_data=login_params, timeout=30) + response = self.getURL(self.urls['login'], post_data=login_params, timeout=30) if not response: logger.log(u"Unable to connect to provider", logger.WARNING) return False @@ -100,7 +90,7 @@ class TorrentBytesProvider(generic.TorrentProvider): logger.log(u"Search string: %s " % search_string, logger.DEBUG) searchURL = self.urls['search'] % (urllib.quote(search_string.encode('utf-8')), self.categories) - logger.log(u"Search URL: %s" % searchURL, logger.DEBUG) + logger.log(u"Search URL: %s" % searchURL, logger.DEBUG) data = self.getURL(searchURL) if not data: @@ -130,7 +120,6 @@ class TorrentBytesProvider(generic.TorrentProvider): else: title = link.contents[0] download_url = self.urls['download'] % (torrent_id, link.contents[0]) - id = int(torrent_id) seeders = int(cells[8].find('span').contents[0]) leechers = int(cells[9].find('span').contents[0]) #FIXME @@ -163,35 +152,6 @@ class TorrentBytesProvider(generic.TorrentProvider): return results - def findPropers(self, search_date=datetime.datetime.today()): - - results = [] - - myDB = db.DBConnection() - sqlResults = myDB.select( - 'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' + - ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' + - ' WHERE e.airdate >= ' + str(search_date.toordinal()) + - ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' + - ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))' - ) - - if not sqlResults: - return [] - - for sqlshow in sqlResults: - self.show = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"])) - if self.show: - curEp = self.show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"])) - - searchString = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK') - - for item in self._doSearch(searchString[0]): - title, url = self._get_title_and_url(item) - results.append(classes.Proper(title, url, datetime.datetime.today(), self.show)) - - return results - def seedRatio(self): return self.ratio diff --git a/sickbeard/providers/torrentday.py b/sickbeard/providers/torrentday.py index 9f1dda8f1ee226fa64ed785859f6cbf7a02e3cfc..82b2cf235dca7cd1a2cdeb80de606b814f1ecdfd 100644 --- a/sickbeard/providers/torrentday.py +++ b/sickbeard/providers/torrentday.py @@ -16,19 +16,10 @@ # along with SickRage. If not, see <http://www.gnu.org/licenses/>. 
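Several providers in this patch (SpeedCD, TorrentBytes, TorrentLeech, TNTVillage) now construct their cache object after categories and proper_strings are set rather than before. A plausible reason, sketched with hypothetical names: the cache constructor receives the provider, so any attribute it reads at construction time must already exist.

class Cache(object):
    def __init__(self, provider):
        # Would raise AttributeError if proper_strings were assigned after the cache.
        self.proper_strings = provider.proper_strings

class Provider(object):
    def __init__(self):
        self.proper_strings = ['PROPER', 'REPACK']
        self.cache = Cache(self)       # built last, once configuration exists

assert Provider().cache.proper_strings == ['PROPER', 'REPACK']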
import re -import datetime -import sickbeard -import generic -from sickbeard.common import Quality +import requests from sickbeard import logger from sickbeard import tvcache -from sickbeard import db -from sickbeard import classes -from sickbeard import helpers -from sickbeard import show_name_helpers -import requests -from sickbeard.helpers import sanitizeSceneName - +from sickbeard.providers import generic class TorrentDayProvider(generic.TorrentProvider): @@ -37,7 +28,7 @@ class TorrentDayProvider(generic.TorrentProvider): generic.TorrentProvider.__init__(self, "TorrentDay") self.supportsBacklog = True - self.public = False + self._uid = None self._hash = None @@ -51,10 +42,9 @@ class TorrentDayProvider(generic.TorrentProvider): self.cache = TorrentDayCache(self) self.urls = {'base_url': 'https://classic.torrentday.com', - 'login': 'https://classic.torrentday.com/torrents/', - 'search': 'https://classic.torrentday.com/V3/API/API.php', - 'download': 'https://classic.torrentday.com/download.php/%s/%s' - } + 'login': 'https://classic.torrentday.com/torrents/', + 'search': 'https://classic.torrentday.com/V3/API/API.php', + 'download': 'https://classic.torrentday.com/download.php/%s/%s'} self.url = self.urls['base_url'] @@ -78,10 +68,9 @@ class TorrentDayProvider(generic.TorrentProvider): login_params = {'username': self.username, 'password': self.password, 'submit.x': 0, - 'submit.y': 0 - } + 'submit.y': 0} - response = self.getURL(self.urls['login'], post_data=login_params, timeout=30) + response = self.getURL(self.urls['login'], post_data=login_params, timeout=30) if not response: logger.log(u"Unable to connect to provider", logger.WARNING) return False @@ -96,10 +85,9 @@ class TorrentDayProvider(generic.TorrentProvider): self._hash = requests.utils.dict_from_cookiejar(self.session.cookies)['pass'] self.cookies = {'uid': self._uid, - 'pass': self._hash - } + 'pass': self._hash} return True - except: + except Exception: pass logger.log(u"Unable to obtain cookie", logger.WARNING) @@ -110,8 +98,6 @@ class TorrentDayProvider(generic.TorrentProvider): results = [] items = {'Season': [], 'Episode': [], 'RSS': []} - freeleech = '&free=on' if self.freeleech else '' - if not self._doLogin(): return results @@ -137,14 +123,14 @@ class TorrentDayProvider(generic.TorrentProvider): try: torrents = parsedJSON.get('Fs', [])[0].get('Cn', {}).get('torrents', []) - except: + except Exception: logger.log(u"Data returned from provider does not contain any torrents", logger.DEBUG) continue for torrent in torrents: title = re.sub(r"\[.*\=.*\].*\[/.*\]", "", torrent['name']) - download_url = self.urls['download'] % ( torrent['id'], torrent['fname'] ) + download_url = self.urls['download'] % ( torrent['id'], torrent['fname']) seeders = int(torrent['seed']) leechers = int(torrent['leech']) #FIXME @@ -172,35 +158,6 @@ class TorrentDayProvider(generic.TorrentProvider): return results - def findPropers(self, search_date=datetime.datetime.today()): - - results = [] - - myDB = db.DBConnection() - sqlResults = myDB.select( - 'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' + - ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' + - ' WHERE e.airdate >= ' + str(search_date.toordinal()) + - ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' + - ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))' - ) - - if not sqlResults: - return [] - - for sqlshow in sqlResults: - self.show = helpers.findCertainShow(sickbeard.showList, 
int(sqlshow["showid"])) - if self.show: - curEp = self.show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"])) - - searchString = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK') - - for item in self._doSearch(searchString[0]): - title, url = self._get_title_and_url(item) - results.append(classes.Proper(title, url, datetime.datetime.today(), self.show)) - - return results - def seedRatio(self): return self.ratio diff --git a/sickbeard/providers/torrentleech.py b/sickbeard/providers/torrentleech.py index 1f30ebdbef8b4b1f74a1401a48f173fb0879d40a..f08df58bb0d304302db3c63eae6265c5e03f580a 100644 --- a/sickbeard/providers/torrentleech.py +++ b/sickbeard/providers/torrentleech.py @@ -18,21 +18,12 @@ import re import traceback -import datetime import urllib -import sickbeard -import generic -from sickbeard.common import Quality from sickbeard import logger from sickbeard import tvcache -from sickbeard import db -from sickbeard import classes -from sickbeard import helpers -from sickbeard import show_name_helpers +from sickbeard.providers import generic from sickbeard.bs4_parser import BS4Parser -from unidecode import unidecode -from sickbeard.helpers import sanitizeSceneName class TorrentLeechProvider(generic.TorrentProvider): @@ -42,29 +33,28 @@ class TorrentLeechProvider(generic.TorrentProvider): generic.TorrentProvider.__init__(self, "TorrentLeech") self.supportsBacklog = True - self.public = False - self.enabled = False self.username = None self.password = None self.ratio = None self.minseed = None self.minleech = None - self.cache = TorrentLeechCache(self) - self.urls = {'base_url': 'https://torrentleech.org/', - 'login': 'https://torrentleech.org/user/account/login/', - 'detail': 'https://torrentleech.org/torrent/%s', - 'search': 'https://torrentleech.org/torrents/browse/index/query/%s/categories/%s', - 'download': 'https://torrentleech.org%s', - 'index': 'https://torrentleech.org/torrents/browse/index/categories/%s', - } + 'login': 'https://torrentleech.org/user/account/login/', + 'detail': 'https://torrentleech.org/torrent/%s', + 'search': 'https://torrentleech.org/torrents/browse/index/query/%s/categories/%s', + 'download': 'https://torrentleech.org%s', + 'index': 'https://torrentleech.org/torrents/browse/index/categories/%s'} self.url = self.urls['base_url'] self.categories = "2,7,26,27,32,34,35" + self.proper_strings = ['PROPER', 'REPACK'] + + self.cache = TorrentLeechCache(self) + def isEnabled(self): return self.enabled @@ -73,10 +63,9 @@ class TorrentLeechProvider(generic.TorrentProvider): login_params = {'username': self.username, 'password': self.password, 'remember_me': 'on', - 'login': 'submit', - } + 'login': 'submit'} - response = self.getURL(self.urls['login'], post_data=login_params, timeout=30) + response = self.getURL(self.urls['login'], post_data=login_params, timeout=30) if not response: logger.log(u"Unable to connect to provider", logger.WARNING) return False @@ -127,7 +116,6 @@ class TorrentLeechProvider(generic.TorrentProvider): url = result.find('td', attrs={'class': 'quickdownload'}).find('a') title = link.string download_url = self.urls['download'] % url['href'] - id = int(link['href'].replace('/torrent/', '')) seeders = int(result.find('td', attrs={'class': 'seeders'}).string) leechers = int(result.find('td', attrs={'class': 'leechers'}).string) #FIXME @@ -160,35 +148,6 @@ class TorrentLeechProvider(generic.TorrentProvider): return results - def findPropers(self, search_date=datetime.datetime.today()): - - results = [] - - myDB = 
db.DBConnection() - sqlResults = myDB.select( - 'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' + - ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' + - ' WHERE e.airdate >= ' + str(search_date.toordinal()) + - ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' + - ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))' - ) - - if not sqlResults: - return [] - - for sqlshow in sqlResults: - self.show = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"])) - if self.show: - curEp = self.show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"])) - - searchString = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK') - - for item in self._doSearch(searchString[0]): - title, url = self._get_title_and_url(item) - results.append(classes.Proper(title, url, datetime.datetime.today(), self.show)) - - return results - def seedRatio(self): return self.ratio diff --git a/sickbeard/providers/torrentproject.py b/sickbeard/providers/torrentproject.py index 8f45fa1ec172e107efbd5c442fab50809a9fd881..bf3361d4cb1341baff3224f096c1e5e482ac66ba 100644 --- a/sickbeard/providers/torrentproject.py +++ b/sickbeard/providers/torrentproject.py @@ -16,19 +16,12 @@ # You should have received a copy of the GNU General Public License # along with SickRage. If not, see <http://www.gnu.org/licenses/>. -import datetime -import generic -import json from urllib import quote_plus from sickbeard import logger from sickbeard import tvcache -from sickbeard import show_name_helpers -from sickbeard import db -from sickbeard.common import WANTED +from sickbeard.providers import generic from sickbeard.common import USER_AGENT -from sickbeard.config import naming_ep_type -from sickbeard.helpers import sanitizeSceneName class TORRENTPROJECTProvider(generic.TorrentProvider): @@ -59,7 +52,7 @@ class TORRENTPROJECTProvider(generic.TorrentProvider): if mode != 'RSS': logger.log(u"Search string: %s " % search_string, logger.DEBUG) - searchURL = self.urls['api'] + "?s=%s&out=json" % quote_plus(search_string.encode('utf-8')) + searchURL = self.urls['api'] + "?s=%s&out=json&filter=2101" % quote_plus(search_string.encode('utf-8')) logger.log(u"Search URL: %s" % searchURL, logger.DEBUG) torrents = self.getURL(searchURL, json=True) if not (torrents and "total_found" in torrents and int(torrents["total_found"]) > 0): @@ -79,19 +72,19 @@ class TORRENTPROJECTProvider(generic.TorrentProvider): logger.log("Torrent doesn't meet minimum seeds & leechers not selecting : %s" % title, logger.DEBUG) continue - hash = torrents[i]["torrent_hash"] + t_hash = torrents[i]["torrent_hash"] size = int(torrents[i]["torrent_size"]) - if seeders < 10 : + if seeders < 10: logger.log("Torrent has less than 10 seeds getting dyn trackers: " + title, logger.DEBUG) - trackerUrl = self.urls['api'] + "" + hash + "/trackers_json" + trackerUrl = self.urls['api'] + "" + t_hash + "/trackers_json" jdata = self.getURL(trackerUrl, json=True) - download_url = "magnet:?xt=urn:btih:" + hash + "&dn=" + title + "".join(["&tr=" + s for s in jdata]) + download_url = "magnet:?xt=urn:btih:" + t_hash + "&dn=" + title + "".join(["&tr=" + s for s in jdata]) logger.log("Dyn Magnet: " + download_url, logger.DEBUG) else: #logger.log("Torrent has more than 10 seeds using hard coded trackers", logger.DEBUG) - download_url = "magnet:?xt=urn:btih:" + hash + "&dn=" + title + 
"&tr=udp://tracker.openbittorrent.com:80&tr=udp://tracker.publicbt.com:80&tr=http://tracker.coppersurfer.tk:6969/announce&tr=http://genesis.1337x.org:1337/announce&tr=http://nemesis.1337x.org/announce&tr=http://erdgeist.org/arts/software/opentracker/announce&tr=http://tracker.ccc.de/announce&tr=http://www.eddie4.nl:6969/announce&tr=http://tracker.leechers-paradise.org:6969/announce" - + download_url = "magnet:?xt=urn:btih:" + t_hash + "&dn=" + title + "&tr=udp://tracker.openbittorrent.com:80&tr=udp://tracker.coppersurfer.tk:6969&tr=udp://open.demonii.com:1337&tr=udp://tracker.leechers-paradise.org:6969&tr=udp://exodus.desync.com:6969" + if not all([title, download_url]): continue @@ -121,7 +114,7 @@ class TORRENTPROJECTCache(tvcache.TVCache): def _getRSSData(self): # no rss for torrentproject afaik,& can't search with empty string # newest results are always > 1 day since added anyways - search_strings = {'RSS': ['']} + # search_strings = {'RSS': ['']} return {'entries': {}} provider = TORRENTPROJECTProvider() diff --git a/sickbeard/providers/transmitthenet.py b/sickbeard/providers/transmitthenet.py index 194535a5185a102c54a14c7be04a24698987a4d4..136cc1f54168cbe280ca67bb855b48ae3dcb74ad 100644 --- a/sickbeard/providers/transmitthenet.py +++ b/sickbeard/providers/transmitthenet.py @@ -15,21 +15,13 @@ import re import traceback -import datetime -import sickbeard -import generic +from urllib import urlencode -from sickbeard.common import Quality from sickbeard import logger from sickbeard import tvcache -from sickbeard import db -from sickbeard import classes -from sickbeard import helpers -from sickbeard import show_name_helpers -from sickbeard.helpers import sanitizeSceneName from sickbeard.bs4_parser import BS4Parser +from sickbeard.providers import generic from sickrage.helper.exceptions import AuthException -from urllib import urlencode class TransmitTheNetProvider(generic.TorrentProvider): @@ -45,7 +37,7 @@ class TransmitTheNetProvider(generic.TorrentProvider): self.url = self.urls['base_url'] self.supportsBacklog = True - self.public = False + self.username = None self.password = None self.ratio = None @@ -133,7 +125,6 @@ class TransmitTheNetProvider(generic.TorrentProvider): title = torrent_row.find('a', {"data-src": True})['data-src'].rsplit('.', 1)[0] download_href = torrent_row.find('img', {"alt": 'Download Torrent'}).findParent()['href'] - id = torrent_row.find('a', {"data-src": True})['href'].split("&id=", 1)[1] seeders = int(torrent_row.findAll('a', {'title': 'Click here to view peers details'})[0].text.strip()) leechers = int(torrent_row.findAll('a', {'title': 'Click here to view peers details'})[1].text.strip()) download_url = self.urls['base_url'] + download_href @@ -155,7 +146,7 @@ class TransmitTheNetProvider(generic.TorrentProvider): items[mode].append(item) - except: + except Exception: logger.log(u"Failed parsing provider. 
Traceback: %s" % traceback.format_exc(), logger.ERROR) # For each search mode sort all the items by seeders @@ -165,32 +156,6 @@ class TransmitTheNetProvider(generic.TorrentProvider): return results - def findPropers(self, search_date=datetime.datetime.today()): - - results = [] - - myDB = db.DBConnection() - sqlResults = myDB.select( - 'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' + - ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' + - ' WHERE e.airdate >= ' + str(search_date.toordinal()) + - ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' + - ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))' - ) - - for sqlshow in sqlResults or []: - self.show = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"])) - if self.show: - curEp = self.show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"])) - - searchString = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK') - - for item in self._doSearch(searchString[0]): - title, url = self._get_title_and_url(item) - results.append(classes.Proper(title, url, datetime.datetime.today(), self.show)) - - return results - def seedRatio(self): return self.ratio diff --git a/sickbeard/providers/tvchaosuk.py b/sickbeard/providers/tvchaosuk.py index 6842d68a58cbc2a019e500932568d6ff3a7f9339..30dbc2f9aaa5242598e51602e5c954949e456b76 100644 --- a/sickbeard/providers/tvchaosuk.py +++ b/sickbeard/providers/tvchaosuk.py @@ -1,4 +1,4 @@ -# This file is part of SickRage. +# This file is part of SickRage. # # SickRage is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,40 +14,31 @@ # along with SickRage. If not, see <http://www.gnu.org/licenses/>. import re -import datetime -import sickbeard -import generic +#from urllib import urlencode -from sickbeard.common import Quality +import sickbeard from sickbeard import logger from sickbeard import tvcache -from sickbeard import db -from sickbeard import classes -from sickbeard import helpers from sickbeard import show_name_helpers from sickbeard.helpers import sanitizeSceneName +from sickbeard.providers import generic from sickbeard.bs4_parser import BS4Parser from sickrage.helper.exceptions import AuthException -from urllib import urlencode - class TVChaosUKProvider(generic.TorrentProvider): def __init__(self): generic.TorrentProvider.__init__(self, 'TvChaosUK') - self.urls = { - 'base_url': 'https://tvchaosuk.com/', - 'login': 'https://tvchaosuk.com/takelogin.php', - 'index': 'https://tvchaosuk.com/index.php', - 'search': 'https://tvchaosuk.com/browse.php' - } + self.urls = {'base_url': 'https://tvchaosuk.com/', + 'login': 'https://tvchaosuk.com/takelogin.php', + 'index': 'https://tvchaosuk.com/index.php', + 'search': 'https://tvchaosuk.com/browse.php'} self.url = self.urls['base_url'] self.supportsBacklog = True - self.public = False - self.enabled = False + self.username = None self.password = None self.ratio = None @@ -148,7 +139,7 @@ class TVChaosUKProvider(generic.TorrentProvider): self.search_params['keywords'] = search_string.strip() data = self.getURL(self.urls['search'], params=self.search_params) - url_searched = self.urls['search'] + '?' + urlencode(self.search_params) + #url_searched = self.urls['search'] + '?' 
+ urlencode(self.search_params) if not data: logger.log("No data returned from provider", logger.DEBUG) @@ -193,7 +184,7 @@ class TVChaosUKProvider(generic.TorrentProvider): items[mode].append(item) - except: + except Exception: continue #For each search mode sort all the items by seeders if available @@ -203,32 +194,6 @@ class TVChaosUKProvider(generic.TorrentProvider): return results - def findPropers(self, search_date=datetime.datetime.today()): - - results = [] - - myDB = db.DBConnection() - sqlResults = myDB.select( - 'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' + - ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' + - ' WHERE e.airdate >= ' + str(search_date.toordinal()) + - ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' + - ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))' - ) - - for sqlshow in sqlResults or []: - self.show = helpers.findCertainShow(sickbeard.showList, int(sqlshow['showid'])) - if self.show: - curEp = self.show.getEpisode(int(sqlshow['season']), int(sqlshow['episode'])) - - searchString = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK') - - for item in self._doSearch(searchString[0]): - title, url = self._get_title_and_url(item) - results.append(classes.Proper(title, url, datetime.datetime.today(), self.show)) - - return results - def seedRatio(self): return self.ratio diff --git a/sickbeard/providers/xthor.py b/sickbeard/providers/xthor.py index bb431f14fe1a14b3904860b4bfc3b230e5f474bc..0b0623c34e7cc460c06c9303f3548bfa1148b717 100644 --- a/sickbeard/providers/xthor.py +++ b/sickbeard/providers/xthor.py @@ -18,21 +18,13 @@ # along with SickRage. If not, see <http://www.gnu.org/licenses/>. import re -import datetime -import sickbeard -import generic import cookielib import urllib import requests -from sickbeard.bs4_parser import BS4Parser -from sickbeard.common import Quality + from sickbeard import logger -from sickbeard import show_name_helpers -from sickbeard import db -from sickbeard import helpers -from unidecode import unidecode -from sickbeard import classes -from sickbeard.helpers import sanitizeSceneName +from sickbeard.providers import generic +from sickbeard.bs4_parser import BS4Parser class XthorProvider(generic.TorrentProvider): @@ -42,7 +34,6 @@ class XthorProvider(generic.TorrentProvider): generic.TorrentProvider.__init__(self, "Xthor") self.supportsBacklog = True - self.public = False self.cj = cookielib.CookieJar() @@ -50,7 +41,6 @@ class XthorProvider(generic.TorrentProvider): self.urlsearch = "https://xthor.bz/browse.php?search=\"%s\"%s" self.categories = "&searchin=title&incldead=0" - self.enabled = False self.username = None self.password = None self.ratio = None @@ -65,10 +55,9 @@ class XthorProvider(generic.TorrentProvider): login_params = {'username': self.username, 'password': self.password, - 'submitme': 'X' - } + 'submitme': 'X'} - response = self.getURL(self.url + '/takelogin.php', post_data=login_params, timeout=30) + response = self.getURL(self.url + '/takelogin.php', post_data=login_params, timeout=30) if not response: logger.log(u"Unable to connect to provider", logger.WARNING) return False @@ -105,14 +94,14 @@ class XthorProvider(generic.TorrentProvider): continue with BS4Parser(data, features=["html5lib", "permissive"]) as html: - resultsTable = html.find("table", { "class" : "table2 table-bordered2" }) + resultsTable = html.find("table", {"class" : "table2 table-bordered2"}) if resultsTable: rows = 
resultsTable.findAll("tr") for row in rows: - link = row.find("a",href=re.compile("details.php")) + link = row.find("a", href=re.compile("details.php")) if link: title = link.text - download_url = self.url + '/' + row.find("a",href=re.compile("download.php"))['href'] + download_url = self.url + '/' + row.find("a", href=re.compile("download.php"))['href'] #FIXME size = -1 seeders = 1 @@ -143,32 +132,4 @@ class XthorProvider(generic.TorrentProvider): def seedRatio(self): return self.ratio - def findPropers(self, search_date=datetime.datetime.today()): - - results = [] - - myDB = db.DBConnection() - sqlResults = myDB.select( - 'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' + - ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' + - ' WHERE e.airdate >= ' + str(search_date.toordinal()) + - ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' + - ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))' - ) - - if not sqlResults: - return results - - for sqlshow in sqlResults: - self.show = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"])) - if self.show: - curEp = self.show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"])) - search_params = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK') - - for item in self._doSearch(search_params[0]): - title, url = self._get_title_and_url(item) - results.append(classes.Proper(title, url, datetime.datetime.today(), self.show)) - - return results - provider = XthorProvider() diff --git a/sickbeard/show_name_helpers.py b/sickbeard/show_name_helpers.py index 710652736aa6f94eb39775aab7c19660956b0519..392d17af0612b7ac30b3ebbdde5a1f66ff764527 100644 --- a/sickbeard/show_name_helpers.py +++ b/sickbeard/show_name_helpers.py @@ -34,13 +34,15 @@ from name_parser.parser import NameParser, InvalidNameException, InvalidShowExce resultFilters = [ "sub(bed|ed|pack|s)", - "(dk|fin|heb|kor|nor|nordic|pl|swe)sub(bed|ed|s)?", "(dir|sample|sub|nfo)fix", "sample", "(dvd)?extras", "dub(bed)?" 
] +if sickbeard.IGNORED_SUBS_LIST: + resultFilters.append("(" + sickbeard.IGNORED_SUBS_LIST.replace(",", "|") + ")sub(bed|ed|s)?") + def containsAtLeastOneWord(name, words): """ diff --git a/sickbeard/tv.py b/sickbeard/tv.py index 6ee0a7fb61056f92f5c09c1e89fd51bde4dc109c..b7d7798330fa43ff5f8f266d9e44ad6b3ea8bec1 100644 --- a/sickbeard/tv.py +++ b/sickbeard/tv.py @@ -2541,7 +2541,7 @@ class TVEpisode(object): min = int((airs.group(2), min)[None is airs.group(2)]) airtime = datetime.time(hr, min) - if sickbeard.TIMEZONE_DISPLAY == 'local': + if sickbeard.FILE_TIMESTAMP_TIMEZONE == 'local': airdatetime = sbdatetime.sbdatetime.convert_to_setting( network_timezones.parse_date_time(datetime.date.toordinal(self.airdate), self.show.airs, self.show.network)) else: airdatetime = datetime.datetime.combine(self.airdate, airtime).replace(tzinfo=tzlocal()) diff --git a/sickbeard/webapi.py b/sickbeard/webapi.py index 0c33ea750f0e74aa3851f3de2c350d098d0a7e21..87a796208440a4cc289adcbf3e7cbca56d0d03f8 100644 --- a/sickbeard/webapi.py +++ b/sickbeard/webapi.py @@ -477,16 +477,9 @@ class TVDBShorthandWrapper(ApiCall): # ############################### -# helper functions # +# helper functions # # ############################### -def _sizeof_fmt(num): - for x in ['bytes', 'KB', 'MB', 'GB', 'TB']: - if num < 1024.00: - return "%3.2f %s" % (num, x) - num /= 1024.00 - - def _is_int(data): try: int(data) @@ -762,7 +755,7 @@ class CMD_Episode(ApiCall): status, quality = Quality.splitCompositeStatus(int(episode["status"])) episode["status"] = _get_status_Strings(status) episode["quality"] = get_quality_string(quality) - episode["file_size_human"] = _sizeof_fmt(episode["file_size"]) + episode["file_size_human"] = helpers.pretty_filesize(episode["file_size"]) return _responds(RESULT_SUCCESS, episode) diff --git a/sickbeard/webserve.py b/sickbeard/webserve.py index c3a4e932c16154c527a1deae7e946f4a14db8e64..64eec882cbb26120a5b1ff985d2281eee335c43f 100644 --- a/sickbeard/webserve.py +++ b/sickbeard/webserve.py @@ -63,6 +63,7 @@ from sickrage.show.History import History as HistoryTool from sickrage.show.Show import Show from sickrage.system.Restart import Restart from sickrage.system.Shutdown import Shutdown + from versionChecker import CheckVersion import requests @@ -3805,7 +3806,7 @@ class ConfigSearch(Config): torrent_dir=None, torrent_username=None, torrent_password=None, torrent_host=None, torrent_label=None, torrent_label_anime=None, torrent_path=None, torrent_verify_cert=None, torrent_seed_time=None, torrent_paused=None, torrent_high_bandwidth=None, - torrent_rpcurl=None, torrent_auth_type = None, ignore_words=None, require_words=None): + torrent_rpcurl=None, torrent_auth_type=None, ignore_words=None, require_words=None, ignored_subs_list=None): results = [] @@ -3828,6 +3829,7 @@ class ConfigSearch(Config): sickbeard.IGNORE_WORDS = ignore_words if ignore_words else "" sickbeard.REQUIRE_WORDS = require_words if require_words else "" + sickbeard.IGNORED_SUBS_LIST = ignored_subs_list if ignored_subs_list else "" sickbeard.RANDOMIZE_PROVIDERS = config.checkbox_to_value(randomize_providers) @@ -3896,7 +3898,7 @@ class ConfigPostProcessing(Config): kodi_data=None, kodi_12plus_data=None, mediabrowser_data=None, sony_ps3_data=None, wdtv_data=None, tivo_data=None, mede8er_data=None, keep_processed_dir=None, process_method=None, del_rar_contents=None, process_automatically=None, - no_delete=None, 
rename_episodes=None, airdate_episodes=None, file_timestamp_timezone=None, unpack=None, move_associated_files=None, sync_files=None, postpone_if_sync_files=None, nfo_rename=None, tv_download_dir=None, naming_custom_abd=None, naming_anime=None,create_missing_show_dirs=None,add_shows_wo_dir=None, @@ -3931,6 +3933,7 @@ class ConfigPostProcessing(Config): sickbeard.EXTRA_SCRIPTS = [x.strip() for x in extra_scripts.split('|') if x.strip()] sickbeard.RENAME_EPISODES = config.checkbox_to_value(rename_episodes) sickbeard.AIRDATE_EPISODES = config.checkbox_to_value(airdate_episodes) + sickbeard.FILE_TIMESTAMP_TIMEZONE = file_timestamp_timezone sickbeard.MOVE_ASSOCIATED_FILES = config.checkbox_to_value(move_associated_files) sickbeard.SYNC_FILES = sync_files sickbeard.POSTPONE_IF_SYNC_FILES = config.checkbox_to_value(postpone_if_sync_files) diff --git a/tornado/__init__.py b/tornado/__init__.py index 6f4f47d2d9f7f60e8c18e4583579e2681f56e742..5588295e49e064b1c542aeca0ff94eb0178bd529 100644 --- a/tornado/__init__.py +++ b/tornado/__init__.py @@ -25,5 +25,5 @@ from __future__ import absolute_import, division, print_function, with_statement # is zero for an official release, positive for a development branch, # or negative for a release candidate or beta (after the base version # number has been incremented) -version = "4.1" -version_info = (4, 1, 0, 0) +version = "4.2.1" +version_info = (4, 2, 1, 0) diff --git a/tornado/auth.py b/tornado/auth.py index ac2fd0d198af06bac288c64421fc099a634fdde1..800b10afe455088e49326f0aaf8ab432c8842ed0 100644 --- a/tornado/auth.py +++ b/tornado/auth.py @@ -32,7 +32,9 @@ They all take slightly different arguments due to the fact all these services implement authentication and authorization slightly differently. See the individual service classes below for complete documentation. -Example usage for Google OpenID:: +Example usage for Google OAuth: + +.. testcode:: class GoogleOAuth2LoginHandler(tornado.web.RequestHandler, tornado.auth.GoogleOAuth2Mixin): @@ -51,6 +53,10 @@ Example usage for Google OpenID:: response_type='code', extra_params={'approval_prompt': 'auto'}) +.. testoutput:: + :hide: + + .. versionchanged:: 4.0 All of the callback interfaces in this module are now guaranteed to run their callback with an argument of ``None`` on error. @@ -69,7 +75,7 @@ import hmac import time import uuid -from tornado.concurrent import TracebackFuture, chain_future, return_future +from tornado.concurrent import TracebackFuture, return_future from tornado import gen from tornado import httpclient from tornado import escape @@ -123,6 +129,7 @@ def _auth_return_future(f): if callback is not None: future.add_done_callback( functools.partial(_auth_future_to_callback, callback)) + def handle_exception(typ, value, tb): if future.done(): return False @@ -138,9 +145,6 @@ def _auth_return_future(f): class OpenIdMixin(object): """Abstract implementation of OpenID and Attribute Exchange. - See `GoogleMixin` below for a customized example (which also - includes OAuth support). - Class attributes: * ``_OPENID_ENDPOINT``: the identity provider's URI. @@ -312,8 +316,7 @@ class OpenIdMixin(object): class OAuthMixin(object): """Abstract implementation of OAuth 1.0 and 1.0a. - See `TwitterMixin` and `FriendFeedMixin` below for example implementations, - or `GoogleMixin` for an OAuth/OpenID hybrid. + See `TwitterMixin` below for an example implementation. Class attributes: @@ -565,7 +568,8 @@ class OAuthMixin(object): class OAuth2Mixin(object): """Abstract implementation of OAuth 2.0. 
- See `FacebookGraphMixin` below for an example implementation. + See `FacebookGraphMixin` or `GoogleOAuth2Mixin` below for example + implementations. Class attributes: @@ -629,7 +633,9 @@ class TwitterMixin(OAuthMixin): URL you registered as your application's callback URL. When your application is set up, you can use this mixin like this - to authenticate the user with Twitter and get access to their stream:: + to authenticate the user with Twitter and get access to their stream: + + .. testcode:: class TwitterLoginHandler(tornado.web.RequestHandler, tornado.auth.TwitterMixin): @@ -641,6 +647,9 @@ class TwitterMixin(OAuthMixin): else: yield self.authorize_redirect() + .. testoutput:: + :hide: + The user object returned by `~OAuthMixin.get_authenticated_user` includes the attributes ``username``, ``name``, ``access_token``, and all of the custom Twitter user attributes described at @@ -689,7 +698,9 @@ class TwitterMixin(OAuthMixin): `~OAuthMixin.get_authenticated_user`. The user returned through that process includes an 'access_token' attribute that can be used to make authenticated requests via this method. Example - usage:: + usage: + + .. testcode:: class MainHandler(tornado.web.RequestHandler, tornado.auth.TwitterMixin): @@ -706,6 +717,9 @@ class TwitterMixin(OAuthMixin): return self.finish("Posted a message!") + .. testoutput:: + :hide: + """ if path.startswith('http:') or path.startswith('https:'): # Raw urls are useful for e.g. search which doesn't follow the @@ -757,223 +771,6 @@ class TwitterMixin(OAuthMixin): raise gen.Return(user) -class FriendFeedMixin(OAuthMixin): - """FriendFeed OAuth authentication. - - To authenticate with FriendFeed, register your application with - FriendFeed at http://friendfeed.com/api/applications. Then copy - your Consumer Key and Consumer Secret to the application - `~tornado.web.Application.settings` ``friendfeed_consumer_key`` - and ``friendfeed_consumer_secret``. Use this mixin on the handler - for the URL you registered as your application's Callback URL. - - When your application is set up, you can use this mixin like this - to authenticate the user with FriendFeed and get access to their feed:: - - class FriendFeedLoginHandler(tornado.web.RequestHandler, - tornado.auth.FriendFeedMixin): - @tornado.gen.coroutine - def get(self): - if self.get_argument("oauth_token", None): - user = yield self.get_authenticated_user() - # Save the user using e.g. set_secure_cookie() - else: - yield self.authorize_redirect() - - The user object returned by `~OAuthMixin.get_authenticated_user()` includes the - attributes ``username``, ``name``, and ``description`` in addition to - ``access_token``. You should save the access token with the user; - it is required to make requests on behalf of the user later with - `friendfeed_request()`. - """ - _OAUTH_VERSION = "1.0" - _OAUTH_REQUEST_TOKEN_URL = "https://friendfeed.com/account/oauth/request_token" - _OAUTH_ACCESS_TOKEN_URL = "https://friendfeed.com/account/oauth/access_token" - _OAUTH_AUTHORIZE_URL = "https://friendfeed.com/account/oauth/authorize" - _OAUTH_NO_CALLBACKS = True - _OAUTH_VERSION = "1.0" - - @_auth_return_future - def friendfeed_request(self, path, callback, access_token=None, - post_args=None, **args): - """Fetches the given relative API path, e.g., "/bret/friends" - - If the request is a POST, ``post_args`` should be provided. Query - string arguments should be given as keyword arguments. - - All the FriendFeed methods are documented at - http://friendfeed.com/api/documentation. 
- - Many methods require an OAuth access token which you can - obtain through `~OAuthMixin.authorize_redirect` and - `~OAuthMixin.get_authenticated_user`. The user returned - through that process includes an ``access_token`` attribute that - can be used to make authenticated requests via this - method. - - Example usage:: - - class MainHandler(tornado.web.RequestHandler, - tornado.auth.FriendFeedMixin): - @tornado.web.authenticated - @tornado.gen.coroutine - def get(self): - new_entry = yield self.friendfeed_request( - "/entry", - post_args={"body": "Testing Tornado Web Server"}, - access_token=self.current_user["access_token"]) - - if not new_entry: - # Call failed; perhaps missing permission? - yield self.authorize_redirect() - return - self.finish("Posted a message!") - - """ - # Add the OAuth resource request signature if we have credentials - url = "http://friendfeed-api.com/v2" + path - if access_token: - all_args = {} - all_args.update(args) - all_args.update(post_args or {}) - method = "POST" if post_args is not None else "GET" - oauth = self._oauth_request_parameters( - url, access_token, all_args, method=method) - args.update(oauth) - if args: - url += "?" + urllib_parse.urlencode(args) - callback = functools.partial(self._on_friendfeed_request, callback) - http = self.get_auth_http_client() - if post_args is not None: - http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args), - callback=callback) - else: - http.fetch(url, callback=callback) - - def _on_friendfeed_request(self, future, response): - if response.error: - future.set_exception(AuthError( - "Error response %s fetching %s" % (response.error, - response.request.url))) - return - future.set_result(escape.json_decode(response.body)) - - def _oauth_consumer_token(self): - self.require_setting("friendfeed_consumer_key", "FriendFeed OAuth") - self.require_setting("friendfeed_consumer_secret", "FriendFeed OAuth") - return dict( - key=self.settings["friendfeed_consumer_key"], - secret=self.settings["friendfeed_consumer_secret"]) - - @gen.coroutine - def _oauth_get_user_future(self, access_token, callback): - user = yield self.friendfeed_request( - "/feedinfo/" + access_token["username"], - include="id,name,description", access_token=access_token) - if user: - user["username"] = user["id"] - callback(user) - - def _parse_user_response(self, callback, user): - if user: - user["username"] = user["id"] - callback(user) - - -class GoogleMixin(OpenIdMixin, OAuthMixin): - """Google Open ID / OAuth authentication. - - .. deprecated:: 4.0 - New applications should use `GoogleOAuth2Mixin` - below instead of this class. As of May 19, 2014, Google has stopped - supporting registration-free authentication. - - No application registration is necessary to use Google for - authentication or to access Google resources on behalf of a user. - - Google implements both OpenID and OAuth in a hybrid mode. If you - just need the user's identity, use - `~OpenIdMixin.authenticate_redirect`. If you need to make - requests to Google on behalf of the user, use - `authorize_redirect`. On return, parse the response with - `~OpenIdMixin.get_authenticated_user`. We send a dict containing - the values for the user, including ``email``, ``name``, and - ``locale``. - - Example usage:: - - class GoogleLoginHandler(tornado.web.RequestHandler, - tornado.auth.GoogleMixin): - @tornado.gen.coroutine - def get(self): - if self.get_argument("openid.mode", None): - user = yield self.get_authenticated_user() - # Save the user with e.g. 
set_secure_cookie() - else: - yield self.authenticate_redirect() - """ - _OPENID_ENDPOINT = "https://www.google.com/accounts/o8/ud" - _OAUTH_ACCESS_TOKEN_URL = "https://www.google.com/accounts/OAuthGetAccessToken" - - @return_future - def authorize_redirect(self, oauth_scope, callback_uri=None, - ax_attrs=["name", "email", "language", "username"], - callback=None): - """Authenticates and authorizes for the given Google resource. - - Some of the available resources which can be used in the ``oauth_scope`` - argument are: - - * Gmail Contacts - http://www.google.com/m8/feeds/ - * Calendar - http://www.google.com/calendar/feeds/ - * Finance - http://finance.google.com/finance/feeds/ - - You can authorize multiple resources by separating the resource - URLs with a space. - - .. versionchanged:: 3.1 - Returns a `.Future` and takes an optional callback. These are - not strictly necessary as this method is synchronous, - but they are supplied for consistency with - `OAuthMixin.authorize_redirect`. - """ - callback_uri = callback_uri or self.request.uri - args = self._openid_args(callback_uri, ax_attrs=ax_attrs, - oauth_scope=oauth_scope) - self.redirect(self._OPENID_ENDPOINT + "?" + urllib_parse.urlencode(args)) - callback() - - @_auth_return_future - def get_authenticated_user(self, callback): - """Fetches the authenticated user data upon redirect.""" - # Look to see if we are doing combined OpenID/OAuth - oauth_ns = "" - for name, values in self.request.arguments.items(): - if name.startswith("openid.ns.") and \ - values[-1] == b"http://specs.openid.net/extensions/oauth/1.0": - oauth_ns = name[10:] - break - token = self.get_argument("openid." + oauth_ns + ".request_token", "") - if token: - http = self.get_auth_http_client() - token = dict(key=token, secret="") - http.fetch(self._oauth_access_token_url(token), - functools.partial(self._on_access_token, callback)) - else: - chain_future(OpenIdMixin.get_authenticated_user(self), - callback) - - def _oauth_consumer_token(self): - self.require_setting("google_consumer_key", "Google OAuth") - self.require_setting("google_consumer_secret", "Google OAuth") - return dict( - key=self.settings["google_consumer_key"], - secret=self.settings["google_consumer_secret"]) - - def _oauth_get_user_future(self, access_token): - return OpenIdMixin.get_authenticated_user(self) - - class GoogleOAuth2Mixin(OAuth2Mixin): """Google authentication using OAuth2. @@ -1001,7 +798,9 @@ class GoogleOAuth2Mixin(OAuth2Mixin): def get_authenticated_user(self, redirect_uri, code, callback): """Handles the login for the Google user, returning a user object. - Example usage:: + Example usage: + + .. testcode:: class GoogleOAuth2LoginHandler(tornado.web.RequestHandler, tornado.auth.GoogleOAuth2Mixin): @@ -1019,6 +818,10 @@ class GoogleOAuth2Mixin(OAuth2Mixin): scope=['profile', 'email'], response_type='code', extra_params={'approval_prompt': 'auto'}) + + .. testoutput:: + :hide: + """ http = self.get_auth_http_client() body = urllib_parse.urlencode({ @@ -1051,217 +854,6 @@ class GoogleOAuth2Mixin(OAuth2Mixin): return httpclient.AsyncHTTPClient() -class FacebookMixin(object): - """Facebook Connect authentication. - - .. deprecated:: 1.1 - New applications should use `FacebookGraphMixin` - below instead of this class. This class does not support the - Future-based interface seen on other classes in this module. - - To authenticate with Facebook, register your application with - Facebook at http://www.facebook.com/developers/apps.php. 
Then - copy your API Key and Application Secret to the application settings - ``facebook_api_key`` and ``facebook_secret``. - - When your application is set up, you can use this mixin like this - to authenticate the user with Facebook:: - - class FacebookHandler(tornado.web.RequestHandler, - tornado.auth.FacebookMixin): - @tornado.web.asynchronous - def get(self): - if self.get_argument("session", None): - self.get_authenticated_user(self._on_auth) - return - yield self.authenticate_redirect() - - def _on_auth(self, user): - if not user: - raise tornado.web.HTTPError(500, "Facebook auth failed") - # Save the user using, e.g., set_secure_cookie() - - The user object returned by `get_authenticated_user` includes the - attributes ``facebook_uid`` and ``name`` in addition to session attributes - like ``session_key``. You should save the session key with the user; it is - required to make requests on behalf of the user later with - `facebook_request`. - """ - @return_future - def authenticate_redirect(self, callback_uri=None, cancel_uri=None, - extended_permissions=None, callback=None): - """Authenticates/installs this app for the current user. - - .. versionchanged:: 3.1 - Returns a `.Future` and takes an optional callback. These are - not strictly necessary as this method is synchronous, - but they are supplied for consistency with - `OAuthMixin.authorize_redirect`. - """ - self.require_setting("facebook_api_key", "Facebook Connect") - callback_uri = callback_uri or self.request.uri - args = { - "api_key": self.settings["facebook_api_key"], - "v": "1.0", - "fbconnect": "true", - "display": "page", - "next": urlparse.urljoin(self.request.full_url(), callback_uri), - "return_session": "true", - } - if cancel_uri: - args["cancel_url"] = urlparse.urljoin( - self.request.full_url(), cancel_uri) - if extended_permissions: - if isinstance(extended_permissions, (unicode_type, bytes)): - extended_permissions = [extended_permissions] - args["req_perms"] = ",".join(extended_permissions) - self.redirect("http://www.facebook.com/login.php?" + - urllib_parse.urlencode(args)) - callback() - - def authorize_redirect(self, extended_permissions, callback_uri=None, - cancel_uri=None, callback=None): - """Redirects to an authorization request for the given FB resource. - - The available resource names are listed at - http://wiki.developers.facebook.com/index.php/Extended_permission. - The most common resource types include: - - * publish_stream - * read_stream - * email - * sms - - extended_permissions can be a single permission name or a list of - names. To get the session secret and session key, call - get_authenticated_user() just as you would with - authenticate_redirect(). - - .. versionchanged:: 3.1 - Returns a `.Future` and takes an optional callback. These are - not strictly necessary as this method is synchronous, - but they are supplied for consistency with - `OAuthMixin.authorize_redirect`. - """ - return self.authenticate_redirect(callback_uri, cancel_uri, - extended_permissions, - callback=callback) - - def get_authenticated_user(self, callback): - """Fetches the authenticated Facebook user. - - The authenticated user includes the special Facebook attributes - 'session_key' and 'facebook_uid' in addition to the standard - user attributes like 'name'. 
- """ - self.require_setting("facebook_api_key", "Facebook Connect") - session = escape.json_decode(self.get_argument("session")) - self.facebook_request( - method="facebook.users.getInfo", - callback=functools.partial( - self._on_get_user_info, callback, session), - session_key=session["session_key"], - uids=session["uid"], - fields="uid,first_name,last_name,name,locale,pic_square," - "profile_url,username") - - def facebook_request(self, method, callback, **args): - """Makes a Facebook API REST request. - - We automatically include the Facebook API key and signature, but - it is the callers responsibility to include 'session_key' and any - other required arguments to the method. - - The available Facebook methods are documented here: - http://wiki.developers.facebook.com/index.php/API - - Here is an example for the stream.get() method:: - - class MainHandler(tornado.web.RequestHandler, - tornado.auth.FacebookMixin): - @tornado.web.authenticated - @tornado.web.asynchronous - def get(self): - self.facebook_request( - method="stream.get", - callback=self._on_stream, - session_key=self.current_user["session_key"]) - - def _on_stream(self, stream): - if stream is None: - # Not authorized to read the stream yet? - self.redirect(self.authorize_redirect("read_stream")) - return - self.render("stream.html", stream=stream) - - """ - self.require_setting("facebook_api_key", "Facebook Connect") - self.require_setting("facebook_secret", "Facebook Connect") - if not method.startswith("facebook."): - method = "facebook." + method - args["api_key"] = self.settings["facebook_api_key"] - args["v"] = "1.0" - args["method"] = method - args["call_id"] = str(long(time.time() * 1e6)) - args["format"] = "json" - args["sig"] = self._signature(args) - url = "http://api.facebook.com/restserver.php?" + \ - urllib_parse.urlencode(args) - http = self.get_auth_http_client() - http.fetch(url, callback=functools.partial( - self._parse_response, callback)) - - def _on_get_user_info(self, callback, session, users): - if users is None: - callback(None) - return - callback({ - "name": users[0]["name"], - "first_name": users[0]["first_name"], - "last_name": users[0]["last_name"], - "uid": users[0]["uid"], - "locale": users[0]["locale"], - "pic_square": users[0]["pic_square"], - "profile_url": users[0]["profile_url"], - "username": users[0].get("username"), - "session_key": session["session_key"], - "session_expires": session.get("expires"), - }) - - def _parse_response(self, callback, response): - if response.error: - gen_log.warning("HTTP error from Facebook: %s", response.error) - callback(None) - return - try: - json = escape.json_decode(response.body) - except Exception: - gen_log.warning("Invalid JSON from Facebook: %r", response.body) - callback(None) - return - if isinstance(json, dict) and json.get("error_code"): - gen_log.warning("Facebook error: %d: %r", json["error_code"], - json.get("error_msg")) - callback(None) - return - callback(json) - - def _signature(self, args): - parts = ["%s=%s" % (n, args[n]) for n in sorted(args.keys())] - body = "".join(parts) + self.settings["facebook_secret"] - if isinstance(body, unicode_type): - body = body.encode("utf-8") - return hashlib.md5(body).hexdigest() - - def get_auth_http_client(self): - """Returns the `.AsyncHTTPClient` instance to be used for auth requests. - - May be overridden by subclasses to use an HTTP client other than - the default. 
- """ - return httpclient.AsyncHTTPClient() - - class FacebookGraphMixin(OAuth2Mixin): """Facebook authentication using the new Graph API and OAuth2.""" _OAUTH_ACCESS_TOKEN_URL = "https://graph.facebook.com/oauth/access_token?" @@ -1274,9 +866,12 @@ class FacebookGraphMixin(OAuth2Mixin): code, callback, extra_fields=None): """Handles the login for the Facebook user, returning a user object. - Example usage:: + Example usage: - class FacebookGraphLoginHandler(LoginHandler, tornado.auth.FacebookGraphMixin): + .. testcode:: + + class FacebookGraphLoginHandler(tornado.web.RequestHandler, + tornado.auth.FacebookGraphMixin): @tornado.gen.coroutine def get(self): if self.get_argument("code", False): @@ -1291,6 +886,10 @@ class FacebookGraphMixin(OAuth2Mixin): redirect_uri='/auth/facebookgraph/', client_id=self.settings["facebook_api_key"], extra_params={"scope": "read_stream,offline_access"}) + + .. testoutput:: + :hide: + """ http = self.get_auth_http_client() args = { @@ -1307,7 +906,7 @@ class FacebookGraphMixin(OAuth2Mixin): http.fetch(self._oauth_request_token_url(**args), functools.partial(self._on_access_token, redirect_uri, client_id, - client_secret, callback, fields)) + client_secret, callback, fields)) def _on_access_token(self, redirect_uri, client_id, client_secret, future, fields, response): @@ -1358,7 +957,9 @@ class FacebookGraphMixin(OAuth2Mixin): process includes an ``access_token`` attribute that can be used to make authenticated requests via this method. - Example usage:: + Example usage: + + ..testcode:: class MainHandler(tornado.web.RequestHandler, tornado.auth.FacebookGraphMixin): @@ -1376,6 +977,9 @@ class FacebookGraphMixin(OAuth2Mixin): return self.finish("Posted a message!") + .. testoutput:: + :hide: + The given path is relative to ``self._FACEBOOK_BASE_URL``, by default "https://graph.facebook.com". diff --git a/tornado/autoreload.py b/tornado/autoreload.py index a548cf02624f1afc42edf15ea2fa5717f709d589..a52ddde40d497d92a89c1ea6fe29981b0a129a8e 100644 --- a/tornado/autoreload.py +++ b/tornado/autoreload.py @@ -100,6 +100,14 @@ try: except ImportError: signal = None +# os.execv is broken on Windows and can't properly parse command line +# arguments and executable name if they contain whitespaces. subprocess +# fixes that behavior. +# This distinction is also important because when we use execv, we want to +# close the IOLoop and all its file descriptors, to guard against any +# file descriptors that were not set CLOEXEC. When execv is not available, +# we must not close the IOLoop because we want the process to exit cleanly. +_has_execv = sys.platform != 'win32' _watched_files = set() _reload_hooks = [] @@ -119,7 +127,8 @@ def start(io_loop=None, check_time=500): _io_loops[io_loop] = True if len(_io_loops) > 1: gen_log.warning("tornado.autoreload started more than once in the same process") - add_reload_hook(functools.partial(io_loop.close, all_fds=True)) + if _has_execv: + add_reload_hook(functools.partial(io_loop.close, all_fds=True)) modify_times = {} callback = functools.partial(_reload_on_update, modify_times) scheduler = ioloop.PeriodicCallback(callback, check_time, io_loop=io_loop) @@ -166,7 +175,7 @@ def _reload_on_update(modify_times): # processes restarted themselves, they'd all restart and then # all call fork_processes again. return - for module in sys.modules.values(): + for module in list(sys.modules.values()): # Some modules play games with sys.modules (e.g. 
email/__init__.py # in the standard library), and occasionally this can cause strange # failures in getattr. Just ignore anything that's not an ordinary @@ -215,10 +224,7 @@ def _reload(): not os.environ.get("PYTHONPATH", "").startswith(path_prefix)): os.environ["PYTHONPATH"] = (path_prefix + os.environ.get("PYTHONPATH", "")) - if sys.platform == 'win32': - # os.execv is broken on Windows and can't properly parse command line - # arguments and executable name if they contain whitespaces. subprocess - # fixes that behavior. + if not _has_execv: subprocess.Popen([sys.executable] + sys.argv) sys.exit(0) else: @@ -238,7 +244,10 @@ def _reload(): # this error specifically. os.spawnv(os.P_NOWAIT, sys.executable, [sys.executable] + sys.argv) - sys.exit(0) + # At this point the IOLoop has been closed and finally + # blocks will experience errors if we allow the stack to + # unwind, so just exit uncleanly. + os._exit(0) _USAGE = """\ Usage: diff --git a/tornado/concurrent.py b/tornado/concurrent.py index acfbcd83e820dfd1ff1e6eb9be8af961d5ae0084..479ca022ef399d36883837d92ee7b12f055ab57b 100644 --- a/tornado/concurrent.py +++ b/tornado/concurrent.py @@ -44,12 +44,14 @@ except ImportError: _GC_CYCLE_FINALIZERS = (platform.python_implementation() == 'CPython' and sys.version_info >= (3, 4)) + class ReturnValueIgnoredError(Exception): pass # This class and associated code in the future object is derived # from the Trollius project, a backport of asyncio to Python 2.x - 3.x + class _TracebackLogger(object): """Helper to log a traceback upon destruction if not cleared. @@ -201,6 +203,10 @@ class Future(object): def result(self, timeout=None): """If the operation succeeded, return its result. If it failed, re-raise its exception. + + This method takes a ``timeout`` argument for compatibility with + `concurrent.futures.Future` but it is an error to call it + before the `Future` is done, so the ``timeout`` is never used. """ self._clear_tb_log() if self._result is not None: @@ -213,6 +219,10 @@ class Future(object): def exception(self, timeout=None): """If the operation raised an exception, return the `Exception` object. Otherwise returns None. + + This method takes a ``timeout`` argument for compatibility with + `concurrent.futures.Future` but it is an error to call it + before the `Future` is done, so the ``timeout`` is never used. """ self._clear_tb_log() if self._exc_info is not None: @@ -289,7 +299,7 @@ class Future(object): try: cb(self) except Exception: - app_log.exception('exception calling callback %r for %r', + app_log.exception('Exception in callback %r for %r', cb, self) self._callbacks = None @@ -335,24 +345,42 @@ class DummyExecutor(object): dummy_executor = DummyExecutor() -def run_on_executor(fn): +def run_on_executor(*args, **kwargs): """Decorator to run a synchronous method asynchronously on an executor. The decorated method may be called with a ``callback`` keyword argument and returns a future. - This decorator should be used only on methods of objects with attributes - ``executor`` and ``io_loop``. + The `.IOLoop` and executor to be used are determined by the ``io_loop`` + and ``executor`` attributes of ``self``. To use different attributes, + pass keyword arguments to the decorator:: + + @run_on_executor(executor='_thread_pool') + def foo(self): + pass + + .. versionchanged:: 4.2 + Added keyword arguments to use alternative attributes. 
""" - @functools.wraps(fn) - def wrapper(self, *args, **kwargs): - callback = kwargs.pop("callback", None) - future = self.executor.submit(fn, self, *args, **kwargs) - if callback: - self.io_loop.add_future(future, - lambda future: callback(future.result())) - return future - return wrapper + def run_on_executor_decorator(fn): + executor = kwargs.get("executor", "executor") + io_loop = kwargs.get("io_loop", "io_loop") + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + callback = kwargs.pop("callback", None) + future = getattr(self, executor).submit(fn, self, *args, **kwargs) + if callback: + getattr(self, io_loop).add_future( + future, lambda future: callback(future.result())) + return future + return wrapper + if args and kwargs: + raise ValueError("cannot combine positional and keyword args") + if len(args) == 1: + return run_on_executor_decorator(args[0]) + elif len(args) != 0: + raise ValueError("expected 1 argument, got %d", len(args)) + return run_on_executor_decorator _NO_RESULT = object() @@ -377,7 +405,9 @@ def return_future(f): wait for the function to complete (perhaps by yielding it in a `.gen.engine` function, or passing it to `.IOLoop.add_future`). - Usage:: + Usage: + + .. testcode:: @return_future def future_func(arg1, arg2, callback): @@ -389,6 +419,8 @@ def return_future(f): yield future_func(arg1, arg2) callback() + .. + Note that ``@return_future`` and ``@gen.engine`` can be applied to the same function, provided ``@return_future`` appears first. However, consider using ``@gen.coroutine`` instead of this combination. diff --git a/tornado/curl_httpclient.py b/tornado/curl_httpclient.py index ebbe0e84b9300485d1acae2ed405b4cf512b0e21..ae6f114a95b92912c30297e48936c43119c673f2 100644 --- a/tornado/curl_httpclient.py +++ b/tornado/curl_httpclient.py @@ -35,6 +35,7 @@ from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main curl_log = logging.getLogger('tornado.curl_httpclient') + class CurlAsyncHTTPClient(AsyncHTTPClient): def initialize(self, io_loop, max_clients=10, defaults=None): super(CurlAsyncHTTPClient, self).initialize(io_loop, defaults=defaults) @@ -207,9 +208,25 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): "callback": callback, "curl_start_time": time.time(), } - self._curl_setup_request(curl, request, curl.info["buffer"], - curl.info["headers"]) - self._multi.add_handle(curl) + try: + self._curl_setup_request( + curl, request, curl.info["buffer"], + curl.info["headers"]) + except Exception as e: + # If there was an error in setup, pass it on + # to the callback. Note that allowing the + # error to escape here will appear to work + # most of the time since we are still in the + # caller's original stack frame, but when + # _process_queue() is called from + # _finish_pending_requests the exceptions have + # nowhere to go. 
+ callback(HTTPResponse( + request=request, + code=599, + error=e)) + else: + self._multi.add_handle(curl) if not started: break @@ -286,10 +303,10 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): curl.setopt(pycurl.HEADERFUNCTION, functools.partial(self._curl_header_callback, - headers, request.header_callback)) + headers, request.header_callback)) if request.streaming_callback: - write_function = lambda chunk: self.io_loop.add_callback( - request.streaming_callback, chunk) + def write_function(chunk): + self.io_loop.add_callback(request.streaming_callback, chunk) else: write_function = buffer.write if bytes is str: # py2 @@ -381,6 +398,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): % request.method) request_buffer = BytesIO(utf8(request.body)) + def ioctl(cmd): if cmd == curl.IOCMD_RESTARTREAD: request_buffer.seek(0) @@ -404,7 +422,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): curl.setopt(pycurl.USERPWD, native_str(userpwd)) curl_log.debug("%s %s (username: %r)", request.method, request.url, - request.auth_username) + request.auth_username) else: curl.unsetopt(pycurl.USERPWD) curl_log.debug("%s %s", request.method, request.url) @@ -415,6 +433,9 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): if request.client_key is not None: curl.setopt(pycurl.SSLKEY, request.client_key) + if request.ssl_options is not None: + raise ValueError("ssl_options not supported in curl_httpclient") + if threading.activeCount() > 1: # libcurl/pycurl is not thread-safe by default. When multiple threads # are used, signals should be disabled. This has the side effect diff --git a/tornado/escape.py b/tornado/escape.py index 24be2264ba85b54ee8efaf24c58ccafc7e6bd335..2f04b4683ae7cccecbbb2571d354c4a40eea3a9d 100644 --- a/tornado/escape.py +++ b/tornado/escape.py @@ -82,7 +82,7 @@ def json_encode(value): # JSON permits but does not require forward slashes to be escaped. # This is useful when json data is emitted in a <script> tag # in HTML, as it prevents </script> tags from prematurely terminating - # the javscript. Some json libraries do this escaping by default, + # the javascript. Some json libraries do this escaping by default, # although python's standard library does not, so we do it here. # http://stackoverflow.com/questions/1580647/json-why-are-forward-slashes-escaped return json.dumps(value).replace("</", "<\\/") @@ -378,7 +378,10 @@ def linkify(text, shorten=False, extra_params="", def _convert_entity(m): if m.group(1) == "#": try: - return unichr(int(m.group(2))) + if m.group(2)[:1].lower() == 'x': + return unichr(int(m.group(2)[1:], 16)) + else: + return unichr(int(m.group(2))) except ValueError: return "&#%s;" % m.group(2) try: diff --git a/tornado/gen.py b/tornado/gen.py index 86fe2f19596ea77267af838b65c7faa03541b16a..9145768951dbe035044bf442241222d308b79b75 100644 --- a/tornado/gen.py +++ b/tornado/gen.py @@ -3,7 +3,9 @@ work in an asynchronous environment. Code using the ``gen`` module is technically asynchronous, but it is written as a single generator instead of a collection of separate functions. -For example, the following asynchronous handler:: +For example, the following asynchronous handler: + +.. testcode:: class AsyncHandler(RequestHandler): @asynchronous @@ -16,7 +18,12 @@ For example, the following asynchronous handler:: do_something_with_response(response) self.render("template.html") -could be written with ``gen`` as:: +.. testoutput:: + :hide: + +could be written with ``gen`` as: + +.. 
testcode:: class GenAsyncHandler(RequestHandler): @gen.coroutine @@ -26,12 +33,17 @@ could be written with ``gen`` as:: do_something_with_response(response) self.render("template.html") +.. testoutput:: + :hide: + Most asynchronous functions in Tornado return a `.Future`; yielding this object returns its `~.Future.result`. You can also yield a list or dict of ``Futures``, which will be started at the same time and run in parallel; a list or dict of results will -be returned when they are all finished:: +be returned when they are all finished: + +.. testcode:: @gen.coroutine def get(self): @@ -43,6 +55,9 @@ be returned when they are all finished:: response3 = response_dict['response3'] response4 = response_dict['response4'] +.. testoutput:: + :hide: + If the `~functools.singledispatch` library is available (standard in Python 3.4, available via the `singledispatch <https://pypi.python.org/pypi/singledispatch>`_ package on older @@ -72,6 +87,7 @@ from tornado.concurrent import Future, TracebackFuture, is_future, chain_future from tornado.ioloop import IOLoop from tornado.log import app_log from tornado import stack_context +from tornado.util import raise_exc_info try: from functools import singledispatch # py34+ @@ -124,9 +140,11 @@ def engine(func): which use ``self.finish()`` in place of a callback argument. """ func = _make_coroutine_wrapper(func, replace_callback=False) + @functools.wraps(func) def wrapper(*args, **kwargs): future = func(*args, **kwargs) + def final_callback(future): if future.result() is not None: raise ReturnValueIgnoredError( @@ -263,6 +281,7 @@ class Return(Exception): super(Return, self).__init__() self.value = value + class WaitIterator(object): """Provides an iterator to yield the results of futures as they finish. @@ -277,20 +296,18 @@ class WaitIterator(object): If you need to get the result of each future as soon as possible, or if you need the result of some futures even if others produce - errors, you can use ``WaitIterator``: - - :: + errors, you can use ``WaitIterator``:: wait_iterator = gen.WaitIterator(future1, future2) while not wait_iterator.done(): try: result = yield wait_iterator.next() except Exception as e: - print "Error {} from {}".format(e, wait_iterator.current_future) + print("Error {} from {}".format(e, wait_iterator.current_future)) else: - print "Result {} recieved from {} at {}".format( + print("Result {} received from {} at {}".format( result, wait_iterator.current_future, - wait_iterator.current_index) + wait_iterator.current_index)) Because results are returned as soon as they are available the output from the iterator *will not be in the same order as the @@ -319,10 +336,8 @@ class WaitIterator(object): self.current_index = self.current_future = None self._running_future = None - self_ref = weakref.ref(self) for future in futures: - future.add_done_callback(functools.partial( - self._done_callback, self_ref)) + future.add_done_callback(self._done_callback) def done(self): """Returns True if this iterator has no more results.""" @@ -345,14 +360,11 @@ class WaitIterator(object): return self._running_future - @staticmethod - def _done_callback(self_ref, done): - self = self_ref() - if self is not None: - if self._running_future and not self._running_future.done(): - self._return_result(done) - else: - self._finished.append(done) + def _done_callback(self, done): + if self._running_future and not self._running_future.done(): + self._return_result(done) + else: + self._finished.append(done) def _return_result(self, done): """Called set the 
returned future's state that of the future @@ -478,11 +490,13 @@ def Task(func, *args, **kwargs): yielded. """ future = Future() + def handle_exception(typ, value, tb): if future.done(): return False future.set_exc_info((typ, value, tb)) return True + def set_result(result): if future.done(): return @@ -510,7 +524,7 @@ class YieldFuture(YieldPoint): self.io_loop.add_future(self.future, runner.result_callback(self.key)) else: self.runner = None - self.result = self.future.result() + self.result_fn = self.future.result def is_ready(self): if self.runner is not None: @@ -522,7 +536,7 @@ class YieldFuture(YieldPoint): if self.runner is not None: return self.runner.pop_result(self.key).result() else: - return self.result + return self.result_fn() class Multi(YieldPoint): @@ -536,8 +550,18 @@ class Multi(YieldPoint): Instead of a list, the argument may also be a dictionary whose values are Futures, in which case a parallel dictionary is returned mapping the same keys to their results. + + It is not normally necessary to call this class directly, as it + will be created automatically as needed. However, calling it directly + allows you to use the ``quiet_exceptions`` argument to control + the logging of multiple exceptions. + + .. versionchanged:: 4.2 + If multiple ``YieldPoints`` fail, any exceptions after the first + (which is raised) will be logged. Added the ``quiet_exceptions`` + argument to suppress this logging for selected exception types. """ - def __init__(self, children): + def __init__(self, children, quiet_exceptions=()): self.keys = None if isinstance(children, dict): self.keys = list(children.keys()) @@ -549,6 +573,7 @@ class Multi(YieldPoint): self.children.append(i) assert all(isinstance(i, YieldPoint) for i in self.children) self.unfinished_children = set(self.children) + self.quiet_exceptions = quiet_exceptions def start(self, runner): for i in self.children: @@ -561,14 +586,27 @@ class Multi(YieldPoint): return not self.unfinished_children def get_result(self): - result = (i.get_result() for i in self.children) + result_list = [] + exc_info = None + for f in self.children: + try: + result_list.append(f.get_result()) + except Exception as e: + if exc_info is None: + exc_info = sys.exc_info() + else: + if not isinstance(e, self.quiet_exceptions): + app_log.error("Multiple exceptions in yield list", + exc_info=True) + if exc_info is not None: + raise_exc_info(exc_info) if self.keys is not None: - return dict(zip(self.keys, result)) + return dict(zip(self.keys, result_list)) else: - return list(result) + return list(result_list) -def multi_future(children): +def multi_future(children, quiet_exceptions=()): """Wait for multiple asynchronous futures in parallel. Takes a list of ``Futures`` (but *not* other ``YieldPoints``) and returns @@ -581,12 +619,21 @@ Futures, in which case a parallel dictionary is returned mapping the same keys to their results. - It is not necessary to call `multi_future` explcitly, since the engine will - do so automatically when the generator yields a list of `Futures`. - This function is faster than the `Multi` `YieldPoint` because it does not - require the creation of a stack context. + It is not normally necessary to call `multi_future` explicitly, + since the engine will do so automatically when the generator + yields a list of ``Futures``. However, calling it directly + allows you to use the ``quiet_exceptions`` argument to control + the logging of multiple exceptions. 
+ + This function is faster than the `Multi` `YieldPoint` because it + does not require the creation of a stack context. .. versionadded:: 4.0 + + .. versionchanged:: 4.2 + If multiple ``Futures`` fail, any exceptions after the first (which is + raised) will be logged. Added the ``quiet_exceptions`` + argument to suppress this logging for selected exception types. """ if isinstance(children, dict): keys = list(children.keys()) @@ -599,20 +646,32 @@ def multi_future(children): future = Future() if not children: future.set_result({} if keys is not None else []) + def callback(f): unfinished_children.remove(f) if not unfinished_children: - try: - result_list = [i.result() for i in children] - except Exception: - future.set_exc_info(sys.exc_info()) - else: + result_list = [] + for f in children: + try: + result_list.append(f.result()) + except Exception as e: + if future.done(): + if not isinstance(e, quiet_exceptions): + app_log.error("Multiple exceptions in yield list", + exc_info=True) + else: + future.set_exc_info(sys.exc_info()) + if not future.done(): if keys is not None: future.set_result(dict(zip(keys, result_list))) else: future.set_result(result_list) + + listening = set() for f in children: - f.add_done_callback(callback) + if f not in listening: + listening.add(f) + f.add_done_callback(callback) return future @@ -664,6 +723,7 @@ def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()): chain_future(future, result) if io_loop is None: io_loop = IOLoop.current() + def error_callback(future): try: future.result() @@ -671,6 +731,7 @@ def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()): if not isinstance(e, quiet_exceptions): app_log.error("Exception in Future %r after timeout", future, exc_info=True) + def timeout_callback(): result.set_exception(TimeoutError("Timeout")) # In case the wrapped future goes on to fail, log it. @@ -803,13 +864,20 @@ class Runner(object): self.future = None try: orig_stack_contexts = stack_context._state.contexts + exc_info = None + try: value = future.result() except Exception: self.had_exception = True - yielded = self.gen.throw(*sys.exc_info()) + exc_info = sys.exc_info() + + if exc_info is not None: + yielded = self.gen.throw(*exc_info) + exc_info = None else: yielded = self.gen.send(value) + if stack_context._state.contexts is not orig_stack_contexts: self.gen.throw( stack_context.StackContextInconsistentError( @@ -846,16 +914,17 @@ class Runner(object): # Lists containing YieldPoints require stack contexts; # other lists are handled via multi_future in convert_yielded. if (isinstance(yielded, list) and - any(isinstance(f, YieldPoint) for f in yielded)): + any(isinstance(f, YieldPoint) for f in yielded)): yielded = Multi(yielded) elif (isinstance(yielded, dict) and - any(isinstance(f, YieldPoint) for f in yielded.values())): + any(isinstance(f, YieldPoint) for f in yielded.values())): yielded = Multi(yielded) if isinstance(yielded, YieldPoint): # YieldPoints are too closely coupled to the Runner to go # through the generic convert_yielded mechanism. 
self.future = TracebackFuture() + def start_yield_point(): try: yielded.start(self) @@ -874,6 +943,7 @@ class Runner(object): with stack_context.ExceptionStackContext( self.handle_exception) as deactivate: self.stack_context_deactivate = deactivate + def cb(): start_yield_point() self.run() diff --git a/tornado/http1connection.py b/tornado/http1connection.py index 181319c42e4e192c74cc4bb20dc80d51b3bbc86a..6226ef7af2c85d4489598bd1b9a48a3652a6afea 100644 --- a/tornado/http1connection.py +++ b/tornado/http1connection.py @@ -37,6 +37,7 @@ class _QuietException(Exception): def __init__(self): pass + class _ExceptionLoggingContext(object): """Used with the ``with`` statement when calling delegate methods to log any exceptions with the given logger. Any exceptions caught are @@ -53,6 +54,7 @@ class _ExceptionLoggingContext(object): self.logger.error("Uncaught exception", exc_info=(typ, value, tb)) raise _QuietException + class HTTP1ConnectionParameters(object): """Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`. """ @@ -202,7 +204,7 @@ class HTTP1Connection(httputil.HTTPConnection): # 1xx responses should never indicate the presence of # a body. if ('Content-Length' in headers or - 'Transfer-Encoding' in headers): + 'Transfer-Encoding' in headers): raise httputil.HTTPInputError( "Response code %d cannot have body" % code) # TODO: client delegates will get headers_received twice diff --git a/tornado/httpclient.py b/tornado/httpclient.py index 0ae9e4802fba353cd62e5854813232f88540561f..c2e6862361691251659b185f097ae389a1baa654 100644 --- a/tornado/httpclient.py +++ b/tornado/httpclient.py @@ -72,7 +72,7 @@ class HTTPClient(object): http_client.close() """ def __init__(self, async_client_class=None, **kwargs): - self._io_loop = IOLoop() + self._io_loop = IOLoop(make_current=False) if async_client_class is None: async_client_class = AsyncHTTPClient self._async_client = async_client_class(self._io_loop, **kwargs) @@ -100,7 +100,6 @@ class HTTPClient(object): """ response = self._io_loop.run_sync(functools.partial( self._async_client.fetch, request, **kwargs)) - response.rethrow() return response @@ -310,7 +309,8 @@ class HTTPRequest(object): validate_cert=None, ca_certs=None, allow_ipv6=None, client_key=None, client_cert=None, body_producer=None, - expect_100_continue=False, decompress_response=None): + expect_100_continue=False, decompress_response=None, + ssl_options=None): r"""All parameters except ``url`` are optional. :arg string url: URL to fetch @@ -380,12 +380,15 @@ class HTTPRequest(object): :arg string ca_certs: filename of CA certificates in PEM format, or None to use defaults. See note below when used with ``curl_httpclient``. - :arg bool allow_ipv6: Use IPv6 when available? Default is false in - ``simple_httpclient`` and true in ``curl_httpclient`` :arg string client_key: Filename for client SSL key, if any. See note below when used with ``curl_httpclient``. :arg string client_cert: Filename for client SSL certificate, if any. See note below when used with ``curl_httpclient``. + :arg ssl.SSLContext ssl_options: `ssl.SSLContext` object for use in + ``simple_httpclient`` (unsupported by ``curl_httpclient``). + Overrides ``validate_cert``, ``ca_certs``, ``client_key``, + and ``client_cert``. + :arg bool allow_ipv6: Use IPv6 when available? Default is true. :arg bool expect_100_continue: If true, send the ``Expect: 100-continue`` header and wait for a continue response before sending the request body. Only supported with @@ -408,6 +411,9 @@ class HTTPRequest(object): .. 
versionadded:: 4.0 The ``body_producer`` and ``expect_100_continue`` arguments. + + .. versionadded:: 4.2 + The ``ssl_options`` argument. """ # Note that some of these attributes go through property setters # defined below. @@ -445,6 +451,7 @@ class HTTPRequest(object): self.allow_ipv6 = allow_ipv6 self.client_key = client_key self.client_cert = client_cert + self.ssl_options = ssl_options self.expect_100_continue = expect_100_continue self.start_time = time.time() diff --git a/tornado/httpserver.py b/tornado/httpserver.py index e470e0e7d153418a940ccb4526007374f783ae90..2dd04dd7a87a0550f15ccab3e7a9ce61c8a2200e 100644 --- a/tornado/httpserver.py +++ b/tornado/httpserver.py @@ -37,9 +37,11 @@ from tornado import httputil from tornado import iostream from tornado import netutil from tornado.tcpserver import TCPServer +from tornado.util import Configurable -class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): +class HTTPServer(TCPServer, Configurable, + httputil.HTTPServerConnectionDelegate): r"""A non-blocking, single-threaded HTTP server. A server is defined by a subclass of `.HTTPServerConnectionDelegate`, @@ -60,15 +62,15 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): if Tornado is run behind an SSL-decoding proxy that does not set one of the supported ``xheaders``. - To make this server serve SSL traffic, send the ``ssl_options`` dictionary - argument with the arguments required for the `ssl.wrap_socket` method, - including ``certfile`` and ``keyfile``. (In Python 3.2+ you can pass - an `ssl.SSLContext` object instead of a dict):: + To make this server serve SSL traffic, send the ``ssl_options`` keyword + argument with an `ssl.SSLContext` object. For compatibility with older + versions of Python ``ssl_options`` may also be a dictionary of keyword + arguments for the `ssl.wrap_socket` method.:: - HTTPServer(applicaton, ssl_options={ - "certfile": os.path.join(data_dir, "mydomain.crt"), - "keyfile": os.path.join(data_dir, "mydomain.key"), - }) + ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"), + os.path.join(data_dir, "mydomain.key")) + HTTPServer(application, ssl_options=ssl_ctx) `HTTPServer` initialization follows one of three patterns (the initialization methods are defined on `tornado.tcpserver.TCPServer`): @@ -77,7 +79,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): server = HTTPServer(app) server.listen(8888) - IOLoop.instance().start() + IOLoop.current().start() In many cases, `tornado.web.Application.listen` can be used to avoid the need to explicitly create the `HTTPServer`. @@ -88,7 +90,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): server = HTTPServer(app) server.bind(8888) server.start(0) # Forks multiple sub-processes - IOLoop.instance().start() + IOLoop.current().start() When using this interface, an `.IOLoop` must *not* be passed to the `HTTPServer` constructor. 
`~.TCPServer.start` will always start @@ -100,7 +102,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): tornado.process.fork_processes(0) server = HTTPServer(app) server.add_sockets(sockets) - IOLoop.instance().start() + IOLoop.current().start() The `~.TCPServer.add_sockets` interface is more complicated, but it can be used with `tornado.process.fork_processes` to @@ -119,13 +121,24 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): `.HTTPServerConnectionDelegate.start_request` is now called with two arguments ``(server_conn, request_conn)`` (in accordance with the documentation) instead of one ``(request_conn)``. + + .. versionchanged:: 4.2 + `HTTPServer` is now a subclass of `tornado.util.Configurable`. """ - def __init__(self, request_callback, no_keep_alive=False, io_loop=None, - xheaders=False, ssl_options=None, protocol=None, - decompress_request=False, - chunk_size=None, max_header_size=None, - idle_connection_timeout=None, body_timeout=None, - max_body_size=None, max_buffer_size=None): + def __init__(self, *args, **kwargs): + # Ignore args to __init__; real initialization belongs in + # initialize since we're Configurable. (there's something + # weird in initialization order between this class, + # Configurable, and TCPServer so we can't leave __init__ out + # completely) + pass + + def initialize(self, request_callback, no_keep_alive=False, io_loop=None, + xheaders=False, ssl_options=None, protocol=None, + decompress_request=False, + chunk_size=None, max_header_size=None, + idle_connection_timeout=None, body_timeout=None, + max_body_size=None, max_buffer_size=None): self.request_callback = request_callback self.no_keep_alive = no_keep_alive self.xheaders = xheaders @@ -142,6 +155,14 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): read_chunk_size=chunk_size) self._connections = set() + @classmethod + def configurable_base(cls): + return HTTPServer + + @classmethod + def configurable_default(cls): + return HTTPServer + @gen.coroutine def close_all_connections(self): while self._connections: diff --git a/tornado/httputil.py b/tornado/httputil.py index 9c99b3efa8ec820669dc7198825d295b693ce578..fa5e697c17f306b2c388e6d179e8834f91c2cfe5 100644 --- a/tornado/httputil.py +++ b/tornado/httputil.py @@ -234,6 +234,14 @@ class HTTPHeaders(dict): # default implementation returns dict(self), not the subclass return HTTPHeaders(self) + # Use our overridden copy method for the copy.copy module. + __copy__ = copy + + def __deepcopy__(self, memo_dict): + # Our values are immutable strings, so our standard copy is + # effectively a deep copy. + return self.copy() + class HTTPServerRequest(object): """A single HTTP request. @@ -385,6 +393,8 @@ class HTTPServerRequest(object): to write the response. """ assert isinstance(chunk, bytes) + assert self.version.startswith("HTTP/1."), \ + "deprecated interface only supported in HTTP/1.x" self.connection.write(chunk, callback=callback) def finish(self): @@ -411,15 +421,14 @@ class HTTPServerRequest(object): def get_ssl_certificate(self, binary_form=False): """Returns the client's SSL certificate, if any. 
- To use client certificates, the HTTPServer must have been constructed - with cert_reqs set in ssl_options, e.g.:: + To use client certificates, the HTTPServer's + `ssl.SSLContext.verify_mode` field must be set, e.g.:: - server = HTTPServer(app, - ssl_options=dict( - certfile="foo.crt", - keyfile="foo.key", - cert_reqs=ssl.CERT_REQUIRED, - ca_certs="cacert.crt")) + ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_ctx.load_cert_chain("foo.crt", "foo.key") + ssl_ctx.load_verify_locations("cacerts.pem") + ssl_ctx.verify_mode = ssl.CERT_REQUIRED + server = HTTPServer(app, ssl_options=ssl_ctx) By default, the return value is a dictionary (or None, if no client certificate is present). If ``binary_form`` is true, a @@ -884,6 +893,7 @@ def doctests(): import doctest return doctest.DocTestSuite() + def split_host_and_port(netloc): """Returns ``(host, port)`` tuple from ``netloc``. diff --git a/tornado/ioloop.py b/tornado/ioloop.py index 680dc4016a40f7bfea92a631ddfd8767aca5e6df..67e33b521f41c12e8b14a38f54d1961245cad487 100644 --- a/tornado/ioloop.py +++ b/tornado/ioloop.py @@ -41,6 +41,7 @@ import sys import threading import time import traceback +import math from tornado.concurrent import TracebackFuture, is_future from tornado.log import app_log, gen_log @@ -76,35 +77,52 @@ class IOLoop(Configurable): simultaneous connections, you should use a system that supports either ``epoll`` or ``kqueue``. - Example usage for a simple TCP server:: + Example usage for a simple TCP server: + + .. testcode:: import errno import functools - import ioloop + import tornado.ioloop import socket def connection_ready(sock, fd, events): while True: try: connection, address = sock.accept() - except socket.error, e: + except socket.error as e: if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN): raise return connection.setblocking(0) handle_connection(connection, address) - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.setblocking(0) - sock.bind(("", port)) - sock.listen(128) - - io_loop = ioloop.IOLoop.instance() - callback = functools.partial(connection_ready, sock) - io_loop.add_handler(sock.fileno(), callback, io_loop.READ) - io_loop.start() - + if __name__ == '__main__': + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(0) + sock.bind(("", port)) + sock.listen(128) + + io_loop = tornado.ioloop.IOLoop.current() + callback = functools.partial(connection_ready, sock) + io_loop.add_handler(sock.fileno(), callback, io_loop.READ) + io_loop.start() + + .. testoutput:: + :hide: + + By default, a newly-constructed `IOLoop` becomes the thread's current + `IOLoop`, unless there already is a current `IOLoop`. This behavior + can be controlled with the ``make_current`` argument to the `IOLoop` + constructor: if ``make_current=True``, the new `IOLoop` will always + try to become current and it raises an error if there is already a + current instance. If ``make_current=False``, the new `IOLoop` will + not try to become current. + + .. versionchanged:: 4.2 + Added the ``make_current`` keyword argument to the `IOLoop` + constructor. """ # Constants from the epoll module _EPOLLIN = 0x001 @@ -133,7 +151,8 @@ class IOLoop(Configurable): Most applications have a single, global `IOLoop` running on the main thread. Use this method to get this instance from - another thread. To get the current thread's `IOLoop`, use `current()`. 
+ another thread. In most other cases, it is better to use `current()` + to get the current thread's `IOLoop`. """ if not hasattr(IOLoop, "_instance"): with IOLoop._instance_lock: @@ -182,8 +201,8 @@ class IOLoop(Configurable): one. .. versionchanged:: 4.1 - Added ``instance`` argument to control the - + Added ``instance`` argument to control the fallback to + `IOLoop.instance()`. """ current = getattr(IOLoop._current, "instance", None) if current is None and instance: @@ -225,8 +244,13 @@ class IOLoop(Configurable): from tornado.platform.select import SelectIOLoop return SelectIOLoop - def initialize(self): - if IOLoop.current(instance=False) is None: + def initialize(self, make_current=None): + if make_current is None: + if IOLoop.current(instance=False) is None: + self.make_current() + elif make_current: + if IOLoop.current(instance=False) is not None: + raise RuntimeError("current IOLoop already exists") self.make_current() def close(self, all_fds=False): @@ -393,7 +417,7 @@ class IOLoop(Configurable): # do stuff... if __name__ == '__main__': - IOLoop.instance().run_sync(main) + IOLoop.current().run_sync(main) """ future_cell = [None] @@ -636,8 +660,8 @@ class PollIOLoop(IOLoop): (Linux), `tornado.platform.kqueue.KQueueIOLoop` (BSD and Mac), or `tornado.platform.select.SelectIOLoop` (all platforms). """ - def initialize(self, impl, time_func=None): - super(PollIOLoop, self).initialize() + def initialize(self, impl, time_func=None, **kwargs): + super(PollIOLoop, self).initialize(**kwargs) self._impl = impl if hasattr(self._impl, 'fileno'): set_close_exec(self._impl.fileno()) @@ -742,8 +766,10 @@ class PollIOLoop(IOLoop): # IOLoop is just started once at the beginning. signal.set_wakeup_fd(old_wakeup_fd) old_wakeup_fd = None - except ValueError: # non-main thread - pass + except ValueError: + # Non-main thread, or the previous value of wakeup_fd + # is no longer valid. + old_wakeup_fd = None try: while True: @@ -947,6 +973,11 @@ class PeriodicCallback(object): """Schedules the given callback to be called periodically. The callback is called every ``callback_time`` milliseconds. + Note that the timeout is given in milliseconds, while most other + time-related functions in Tornado use seconds. + + If the callback runs for longer than ``callback_time`` milliseconds, + subsequent invocations will be skipped to get back on schedule. `start` must be called after the `PeriodicCallback` is created.
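A minimal sketch of these semantics (the 500ms interval and the callback body are illustrative only)::

    from tornado.ioloop import IOLoop, PeriodicCallback

    def tick():
        # callback_time is in milliseconds: this runs every 500ms.
        # If tick() ever takes longer than 500ms, the missed runs are
        # skipped (e.g. a callback due at t=10.0s that finishes at
        # t=12.3s is next scheduled at t=13.0s, not at 11.0s and 12.0s).
        print("tick")

    PeriodicCallback(tick, 500).start()
    IOLoop.current().start()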
@@ -995,6 +1026,9 @@ class PeriodicCallback(object): def _schedule_next(self): if self._running: current_time = self.io_loop.time() - while self._next_timeout <= current_time: - self._next_timeout += self.callback_time / 1000.0 + + if self._next_timeout <= current_time: + callback_time_sec = self.callback_time / 1000.0 + self._next_timeout += (math.floor((current_time - self._next_timeout) / callback_time_sec) + 1) * callback_time_sec + self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run) diff --git a/tornado/iostream.py b/tornado/iostream.py index cdb6250b9055fed5bb86fab08f3eb87b82b00cf1..3a175a679671ad4feb545e7471da2e4622e77a34 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -37,7 +37,7 @@ import re from tornado.concurrent import TracebackFuture from tornado import ioloop from tornado.log import gen_log, app_log -from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError +from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError, _client_ssl_defaults, _server_ssl_defaults from tornado import stack_context from tornado.util import errno_from_exception @@ -82,7 +82,7 @@ _ERRNO_INPROGRESS = (errno.EINPROGRESS,) if hasattr(errno, "WSAEINPROGRESS"): _ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,) -####################################################### + class StreamClosedError(IOError): """Exception raised by `IOStream` methods when the stream is closed. @@ -169,6 +169,11 @@ class BaseIOStream(object): self._close_callback = None self._connect_callback = None self._connect_future = None + # _ssl_connect_future should be defined in SSLIOStream + # but it's here so we can clean it up in maybe_run_close_callback. + # TODO: refactor that so subclasses can add additional futures + # to be cancelled. + self._ssl_connect_future = None self._connecting = False self._state = None self._pending_callbacks = 0 @@ -317,9 +322,16 @@ class BaseIOStream(object): If a callback is given, it will be run with the data as an argument; if not, this method returns a `.Future`. + Note that if a ``streaming_callback`` is used, data will be + read from the socket as quickly as it becomes available; there + is no way to apply backpressure or cancel the reads. If flow + control or cancellation are desired, use a loop with + `read_bytes(partial=True) <.read_bytes>` instead. + .. versionchanged:: 4.0 The callback argument is now optional and a `.Future` will be returned if it is omitted. + """ future = self._set_read_callback(callback) self._streaming_callback = stack_context.wrap(streaming_callback) @@ -430,9 +442,11 @@ class BaseIOStream(object): if self._connect_future is not None: futures.append(self._connect_future) self._connect_future = None + if self._ssl_connect_future is not None: + futures.append(self._ssl_connect_future) + self._ssl_connect_future = None for future in futures: - if (isinstance(self.error, (socket.error, IOError)) and - errno_from_exception(self.error) in _ERRNO_CONNRESET): + if self._is_connreset(self.error): # Treat connection resets as closed connections so # clients only have to catch one kind of exception # to avoid logging. 
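The flow-controlled read loop recommended in the ``read_until_close`` note above might look like this sketch (``process`` and the 64KB chunk size are hypothetical placeholders)::

    from tornado import gen
    from tornado.iostream import StreamClosedError

    @gen.coroutine
    def read_all(stream):
        try:
            while True:
                # Reading bounded chunks lets the caller pause between
                # reads, which read_until_close's streaming_callback
                # cannot do.
                chunk = yield stream.read_bytes(64 * 1024, partial=True)
                process(chunk)  # hypothetical application logic
        except StreamClosedError:
            pass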
@@ -655,13 +669,13 @@ class BaseIOStream(object): else: callback = self._read_callback self._read_callback = self._streaming_callback = None - if self._read_future is not None: - assert callback is None - future = self._read_future - self._read_future = None - future.set_result(self._consume(size)) + if self._read_future is not None: + assert callback is None + future = self._read_future + self._read_future = None + future.set_result(self._consume(size)) if callback is not None: - assert self._read_future is None + assert (self._read_future is None) or streaming self._run_callback(callback, self._consume(size)) else: # If we scheduled a callback, we will add the error listener @@ -712,7 +726,7 @@ class BaseIOStream(object): chunk = self.read_from_fd() except (socket.error, IOError, OSError) as e: # ssl.SSLError is a subclass of socket.error - if e.args[0] in _ERRNO_CONNRESET: + if self._is_connreset(e): # Treat ECONNRESET as a connection close rather than # an error to minimize log spam (the exception will # be available on self.error for apps that care). @@ -834,7 +848,7 @@ class BaseIOStream(object): self._write_buffer_frozen = True break else: - if e.args[0] not in _ERRNO_CONNRESET: + if not self._is_connreset(e): # Broken pipe errors are usually caused by connection # reset, and its better to not log EPIPE errors to # minimize log spam @@ -912,6 +926,14 @@ class BaseIOStream(object): self._state = self._state | state self.io_loop.update_handler(self.fileno(), self._state) + def _is_connreset(self, exc): + """Return true if exc is ECONNRESET or equivalent. + + May be overridden in subclasses. + """ + return (isinstance(exc, (socket.error, IOError)) and + errno_from_exception(exc) in _ERRNO_CONNRESET) + class IOStream(BaseIOStream): r"""Socket-based `IOStream` implementation. @@ -926,7 +948,9 @@ class IOStream(BaseIOStream): connected before passing it to the `IOStream` or connected with `IOStream.connect`. - A very simple (and broken) HTTP client using this class:: + A very simple (and broken) HTTP client using this class: + + .. testcode:: import tornado.ioloop import tornado.iostream @@ -945,14 +969,19 @@ class IOStream(BaseIOStream): stream.read_bytes(int(headers[b"Content-Length"]), on_body) def on_body(data): - print data + print(data) stream.close() - tornado.ioloop.IOLoop.instance().stop() + tornado.ioloop.IOLoop.current().stop() + + if __name__ == '__main__': + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + stream = tornado.iostream.IOStream(s) + stream.connect(("friendfeed.com", 80), send_request) + tornado.ioloop.IOLoop.current().start() + + .. testoutput:: + :hide: - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) - stream = tornado.iostream.IOStream(s) - stream.connect(("friendfeed.com", 80), send_request) - tornado.ioloop.IOLoop.instance().start() """ def __init__(self, socket, *args, **kwargs): self.socket = socket @@ -1006,10 +1035,10 @@ class IOStream(BaseIOStream): returns a `.Future` (whose result after a successful connection will be the stream itself). - If specified, the ``server_hostname`` parameter will be used - in SSL connections for certificate validation (if requested in - the ``ssl_options``) and SNI (if supported; requires - Python 3.2+). + In SSL mode, the ``server_hostname`` parameter will be used + for certificate validation (unless disabled in the + ``ssl_options``) and SNI (if supported; requires Python + 2.7.9+). 
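For instance, a validating client connection might be set up like this (host and port are placeholders)::

    import socket
    from tornado.iostream import SSLIOStream

    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    stream = SSLIOStream(s)
    # server_hostname enables certificate validation and SNI.
    stream.connect(("example.com", 443), server_hostname="example.com")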
Note that it is safe to call `IOStream.write <BaseIOStream.write>` while the connection is pending, in @@ -1020,6 +1049,11 @@ .. versionchanged:: 4.0 If no callback is given, returns a `.Future`. + .. versionchanged:: 4.2 + SSL certificates are validated by default; pass + ``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a + suitably-configured `ssl.SSLContext` to the + `SSLIOStream` constructor to disable. """ self._connecting = True if callback is not None: @@ -1062,10 +1096,11 @@ data. It can also be used immediately after connecting, before any reads or writes. - The ``ssl_options`` argument may be either a dictionary - of options or an `ssl.SSLContext`. If a ``server_hostname`` - is given, it will be used for certificate verification - (as configured in the ``ssl_options``). + The ``ssl_options`` argument may be either an `ssl.SSLContext` + object or a dictionary of keyword arguments for the + `ssl.wrap_socket` function. The ``server_hostname`` argument + will be used for certificate validation unless disabled + in the ``ssl_options``. This method returns a `.Future` whose result is the new `SSLIOStream`. After this method has been called, @@ -1075,6 +1110,11 @@ transferred to the new stream. .. versionadded:: 4.0 + + .. versionchanged:: 4.2 + SSL certificates are validated by default; pass + ``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a + suitably-configured `ssl.SSLContext` to disable. """ if (self._read_callback or self._read_future or self._write_callback or self._write_future or @@ -1083,7 +1123,10 @@ self._read_buffer or self._write_buffer): raise ValueError("IOStream is not idle; cannot convert to SSL") if ssl_options is None: - ssl_options = {} + if server_side: + ssl_options = _server_ssl_defaults + else: + ssl_options = _client_ssl_defaults socket = self.socket self.io_loop.remove_handler(socket) @@ -1102,6 +1145,7 @@ # If we had an "unwrap" counterpart to this method we would need # to restore the original callback after our Future resolves # so that repeated wrap/unwrap calls don't build up layers. + def close_callback(): if not future.done(): future.set_exception(ssl_stream.error or StreamClosedError()) @@ -1146,7 +1190,7 @@ # Sometimes setsockopt will fail if the socket is closed # at the wrong time. This can happen with HTTPServer # resetting the value to false between requests. - if e.errno not in (errno.EINVAL, errno.ECONNRESET): + if e.errno != errno.EINVAL and not self._is_connreset(e): raise @@ -1162,11 +1206,11 @@ class SSLIOStream(IOStream): wrapped when `IOStream.connect` is finished. """ def __init__(self, *args, **kwargs): - """The ``ssl_options`` keyword argument may either be a dictionary - of keywords arguments for `ssl.wrap_socket`, or an `ssl.SSLContext` - object. + """The ``ssl_options`` keyword argument may either be an + `ssl.SSLContext` object or a dictionary of keyword arguments + for `ssl.wrap_socket`. """ - self._ssl_options = kwargs.pop('ssl_options', {}) + self._ssl_options = kwargs.pop('ssl_options', _client_ssl_defaults) super(SSLIOStream, self).__init__(*args, **kwargs) self._ssl_accepting = True self._handshake_reading = False @@ -1221,8 +1265,7 @@ class SSLIOStream(IOStream): # to cause do_handshake to raise EBADF, so make that error # quiet as well.
# https://groups.google.com/forum/?fromgroups#!topic/python-tornado/ApucKJat1_0 - if (err.args[0] in _ERRNO_CONNRESET or - err.args[0] == errno.EBADF): + if self._is_connreset(err) or err.args[0] == errno.EBADF: return self.close(exc_info=True) raise except AttributeError: @@ -1235,10 +1278,17 @@ class SSLIOStream(IOStream): if not self._verify_cert(self.socket.getpeercert()): self.close() return - if self._ssl_connect_callback is not None: - callback = self._ssl_connect_callback - self._ssl_connect_callback = None - self._run_callback(callback) + self._run_ssl_connect_callback() + + def _run_ssl_connect_callback(self): + if self._ssl_connect_callback is not None: + callback = self._ssl_connect_callback + self._ssl_connect_callback = None + self._run_callback(callback) + if self._ssl_connect_future is not None: + future = self._ssl_connect_future + self._ssl_connect_future = None + future.set_result(self) def _verify_cert(self, peercert): """Returns True if peercert is valid according to the configured @@ -1280,14 +1330,11 @@ class SSLIOStream(IOStream): super(SSLIOStream, self)._handle_write() def connect(self, address, callback=None, server_hostname=None): - # Save the user's callback and run it after the ssl handshake - # has completed. - self._ssl_connect_callback = stack_context.wrap(callback) self._server_hostname = server_hostname - # Note: Since we don't pass our callback argument along to - # super.connect(), this will always return a Future. - # This is harmless, but a bit less efficient than it could be. - return super(SSLIOStream, self).connect(address, callback=None) + # Pass a dummy callback to super.connect(), which is slightly + # more efficient than letting it return a Future we ignore. + super(SSLIOStream, self).connect(address, callback=lambda: None) + return self.wait_for_handshake(callback) def _handle_connect(self): # Call the superclass method to check for errors. @@ -1312,6 +1359,51 @@ class SSLIOStream(IOStream): do_handshake_on_connect=False) self._add_io_state(old_state) + def wait_for_handshake(self, callback=None): + """Wait for the initial SSL handshake to complete. + + If a ``callback`` is given, it will be called with no + arguments once the handshake is complete; otherwise this + method returns a `.Future` which will resolve to the + stream itself after the handshake is complete. + + Once the handshake is complete, information such as + the peer's certificate and NPN/ALPN selections may be + accessed on ``self.socket``. + + This method is intended for use on server-side streams + or after using `IOStream.start_tls`; it should not be used + with `IOStream.connect` (which already waits for the + handshake to complete). It may only be called once per stream. + + .. versionadded:: 4.2 + """ + if (self._ssl_connect_callback is not None or + self._ssl_connect_future is not None): + raise RuntimeError("Already waiting") + if callback is not None: + self._ssl_connect_callback = stack_context.wrap(callback) + future = None + else: + future = self._ssl_connect_future = TracebackFuture() + if not self._ssl_accepting: + self._run_ssl_connect_callback() + return future + + def write_to_fd(self, data): + try: + return self.socket.send(data) + except ssl.SSLError as e: + if e.args[0] == ssl.SSL_ERROR_WANT_WRITE: + # In Python 3.5+, SSLSocket.send raises a WANT_WRITE error if + # the socket is not writeable; we need to transform this into + # an EWOULDBLOCK socket.error or a zero return value, + # either of which will be recognized by the caller of this + # method. 
Prior to Python 3.5, an unwriteable socket would + # simply return 0 bytes written. + return 0 + raise + def read_from_fd(self): if self._ssl_accepting: # If the handshake hasn't finished yet, there can't be anything @@ -1342,6 +1434,11 @@ return None return chunk + def _is_connreset(self, e): + if isinstance(e, ssl.SSLError) and e.args[0] == ssl.SSL_ERROR_EOF: + return True + return super(SSLIOStream, self)._is_connreset(e) + class PipeIOStream(BaseIOStream): """Pipe-based `IOStream` implementation. diff --git a/tornado/locale.py b/tornado/locale.py index 07c6d582b4e6a4093f71485539df5fd66320eaee..a668765bbc4c6a57f8fd24120a63f03bad8f84d8 100644 --- a/tornado/locale.py +++ b/tornado/locale.py @@ -55,6 +55,7 @@ _default_locale = "en_US" _translations = {} _supported_locales = frozenset([_default_locale]) _use_gettext = False +CONTEXT_SEPARATOR = "\x04" def get(*locale_codes): @@ -273,6 +274,9 @@ class Locale(object): """ raise NotImplementedError() + def pgettext(self, context, message, plural_message=None, count=None): + raise NotImplementedError() + def format_date(self, date, gmt_offset=0, relative=True, shorter=False, full_format=False): """Formats the given date (which should be GMT). @@ -422,6 +426,11 @@ class CSVLocale(Locale): message_dict = self.translations.get("unknown", {}) return message_dict.get(message, message) + def pgettext(self, context, message, plural_message=None, count=None): + if self.translations: + gen_log.warning('pgettext is not supported by CSVLocale') + return self.translate(message, plural_message, count) + class GettextLocale(Locale): """Locale implementation using the `gettext` module.""" @@ -445,6 +454,44 @@ else: return self.gettext(message) + def pgettext(self, context, message, plural_message=None, count=None): + """Allows setting a context for the translation; accepts plural forms. + + Usage example:: + + pgettext("law", "right") + pgettext("good", "right") + + Plural message example:: + + pgettext("organization", "club", "clubs", len(clubs)) + pgettext("stick", "club", "clubs", len(clubs)) + + To generate a POT file with context, add the following options to step 1 + of the `load_gettext_translations` sequence:: + + xgettext [basic options] --keyword=pgettext:1c,2 --keyword=pgettext:1c,2,3 + + .. versionadded:: 4.2 + """ + if plural_message is not None: + assert count is not None + msgs_with_ctxt = ("%s%s%s" % (context, CONTEXT_SEPARATOR, message), + "%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message), + count) + result = self.ngettext(*msgs_with_ctxt) + if CONTEXT_SEPARATOR in result: + # Translation not found + result = self.ngettext(message, plural_message, count) + return result + else: + msg_with_ctxt = "%s%s%s" % (context, CONTEXT_SEPARATOR, message) + result = self.gettext(msg_with_ctxt) + if CONTEXT_SEPARATOR in result: + # Translation not found + result = message + return result + LOCALE_NAMES = { "af_ZA": {"name_en": u("Afrikaans"), "name": u("Afrikaans")}, "am_ET": {"name_en": u("Amharic"), "name": u('\u12a0\u121b\u122d\u129b')}, diff --git a/tornado/locks.py b/tornado/locks.py new file mode 100644 index 0000000000000000000000000000000000000000..4b0bdb38f1ea148bfab350a2b5590602d3a876f3 --- /dev/null +++ b/tornado/locks.py @@ -0,0 +1,460 @@ +# Copyright 2015 The Tornado Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License.
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +.. testsetup:: * + + from tornado import ioloop, gen, locks + io_loop = ioloop.IOLoop.current() +""" + +from __future__ import absolute_import, division, print_function, with_statement + +__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock'] + +import collections + +from tornado import gen, ioloop +from tornado.concurrent import Future + + +class _TimeoutGarbageCollector(object): + """Base class for objects that periodically clean up timed-out waiters. + + Avoids memory leak in a common pattern like: + + while True: + yield condition.wait(short_timeout) + print('looping....') + """ + def __init__(self): + self._waiters = collections.deque() # Futures. + self._timeouts = 0 + + def _garbage_collect(self): + # Occasionally clear timed-out waiters. + self._timeouts += 1 + if self._timeouts > 100: + self._timeouts = 0 + self._waiters = collections.deque( + w for w in self._waiters if not w.done()) + + +class Condition(_TimeoutGarbageCollector): + """A condition allows one or more coroutines to wait until notified. + + Like a standard `threading.Condition`, but does not need an underlying lock + that is acquired and released. + + With a `Condition`, coroutines can wait to be notified by other coroutines: + + .. testcode:: + + condition = locks.Condition() + + @gen.coroutine + def waiter(): + print("I'll wait right here") + yield condition.wait() # Yield a Future. + print("I'm done waiting") + + @gen.coroutine + def notifier(): + print("About to notify") + condition.notify() + print("Done notifying") + + @gen.coroutine + def runner(): + # Yield two Futures; wait for waiter() and notifier() to finish. + yield [waiter(), notifier()] + + io_loop.run_sync(runner) + + .. testoutput:: + + I'll wait right here + About to notify + Done notifying + I'm done waiting + + `wait` takes an optional ``timeout`` argument, which is either an absolute + timestamp:: + + io_loop = ioloop.IOLoop.current() + + # Wait up to 1 second for a notification. + yield condition.wait(timeout=io_loop.time() + 1) + + ...or a `datetime.timedelta` for a timeout relative to the current time:: + + # Wait up to 1 second. + yield condition.wait(timeout=datetime.timedelta(seconds=1)) + + The method raises `tornado.gen.TimeoutError` if there's no notification + before the deadline. + """ + + def __init__(self): + super(Condition, self).__init__() + self.io_loop = ioloop.IOLoop.current() + + def __repr__(self): + result = '<%s' % (self.__class__.__name__, ) + if self._waiters: + result += ' waiters[%s]' % len(self._waiters) + return result + '>' + + def wait(self, timeout=None): + """Wait for `.notify`. + + Returns a `.Future` that resolves ``True`` if the condition is notified, + or ``False`` after a timeout. 
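For example, a waiter that tolerates a missed notification (the one-second deadline is arbitrary)::

    import datetime
    from tornado import gen

    @gen.coroutine
    def cautious_waiter():
        notified = yield condition.wait(
            timeout=datetime.timedelta(seconds=1))
        if not notified:
            print("gave up waiting")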
+ """ + waiter = Future() + self._waiters.append(waiter) + if timeout: + def on_timeout(): + waiter.set_result(False) + self._garbage_collect() + io_loop = ioloop.IOLoop.current() + timeout_handle = io_loop.add_timeout(timeout, on_timeout) + waiter.add_done_callback( + lambda _: io_loop.remove_timeout(timeout_handle)) + return waiter + + def notify(self, n=1): + """Wake ``n`` waiters.""" + waiters = [] # Waiters we plan to run right now. + while n and self._waiters: + waiter = self._waiters.popleft() + if not waiter.done(): # Might have timed out. + n -= 1 + waiters.append(waiter) + + for waiter in waiters: + waiter.set_result(True) + + def notify_all(self): + """Wake all waiters.""" + self.notify(len(self._waiters)) + + +class Event(object): + """An event blocks coroutines until its internal flag is set to True. + + Similar to `threading.Event`. + + A coroutine can wait for an event to be set. Once it is set, calls to + ``yield event.wait()`` will not block unless the event has been cleared: + + .. testcode:: + + event = locks.Event() + + @gen.coroutine + def waiter(): + print("Waiting for event") + yield event.wait() + print("Not waiting this time") + yield event.wait() + print("Done") + + @gen.coroutine + def setter(): + print("About to set the event") + event.set() + + @gen.coroutine + def runner(): + yield [waiter(), setter()] + + io_loop.run_sync(runner) + + .. testoutput:: + + Waiting for event + About to set the event + Not waiting this time + Done + """ + def __init__(self): + self._future = Future() + + def __repr__(self): + return '<%s %s>' % ( + self.__class__.__name__, 'set' if self.is_set() else 'clear') + + def is_set(self): + """Return ``True`` if the internal flag is true.""" + return self._future.done() + + def set(self): + """Set the internal flag to ``True``. All waiters are awakened. + + Calling `.wait` once the flag is set will not block. + """ + if not self._future.done(): + self._future.set_result(None) + + def clear(self): + """Reset the internal flag to ``False``. + + Calls to `.wait` will block until `.set` is called. + """ + if self._future.done(): + self._future = Future() + + def wait(self, timeout=None): + """Block until the internal flag is true. + + Returns a Future, which raises `tornado.gen.TimeoutError` after a + timeout. + """ + if timeout is None: + return self._future + else: + return gen.with_timeout(timeout, self._future) + + +class _ReleasingContextManager(object): + """Releases a Lock or Semaphore at the end of a "with" statement. + + with (yield semaphore.acquire()): + pass + + # Now semaphore.release() has been called. + """ + def __init__(self, obj): + self._obj = obj + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + self._obj.release() + + +class Semaphore(_TimeoutGarbageCollector): + """A lock that can be acquired a fixed number of times before blocking. + + A Semaphore manages a counter representing the number of `.release` calls + minus the number of `.acquire` calls, plus an initial value. The `.acquire` + method blocks if necessary until it can return without making the counter + negative. + + Semaphores limit access to a shared resource. To allow access for two + workers at a time: + + .. testsetup:: semaphore + + from collections import deque + + from tornado import gen, ioloop + from tornado.concurrent import Future + + # Ensure reliable doctest output: resolve Futures one at a time. 
+ futures_q = deque([Future() for _ in range(3)]) + + @gen.coroutine + def simulator(futures): + for f in futures: + yield gen.moment + f.set_result(None) + + ioloop.IOLoop.current().add_callback(simulator, list(futures_q)) + + def use_some_resource(): + return futures_q.popleft() + + .. testcode:: semaphore + + sem = locks.Semaphore(2) + + @gen.coroutine + def worker(worker_id): + yield sem.acquire() + try: + print("Worker %d is working" % worker_id) + yield use_some_resource() + finally: + print("Worker %d is done" % worker_id) + sem.release() + + @gen.coroutine + def runner(): + # Join all workers. + yield [worker(i) for i in range(3)] + + io_loop.run_sync(runner) + + .. testoutput:: semaphore + + Worker 0 is working + Worker 1 is working + Worker 0 is done + Worker 2 is working + Worker 1 is done + Worker 2 is done + + Workers 0 and 1 are allowed to run concurrently, but worker 2 waits until + the semaphore has been released once, by worker 0. + + `.acquire` is a context manager, so ``worker`` could be written as:: + + @gen.coroutine + def worker(worker_id): + with (yield sem.acquire()): + print("Worker %d is working" % worker_id) + yield use_some_resource() + + # Now the semaphore has been released. + print("Worker %d is done" % worker_id) + """ + def __init__(self, value=1): + super(Semaphore, self).__init__() + if value < 0: + raise ValueError('semaphore initial value must be >= 0') + + self._value = value + + def __repr__(self): + res = super(Semaphore, self).__repr__() + extra = 'locked' if self._value == 0 else 'unlocked,value:{0}'.format( + self._value) + if self._waiters: + extra = '{0},waiters:{1}'.format(extra, len(self._waiters)) + return '<{0} [{1}]>'.format(res[1:-1], extra) + + def release(self): + """Increment the counter and wake one waiter.""" + self._value += 1 + while self._waiters: + waiter = self._waiters.popleft() + if not waiter.done(): + self._value -= 1 + + # If the waiter is a coroutine paused at + # + # with (yield semaphore.acquire()): + # + # then the context manager's __exit__ calls release() at the end + # of the "with" block. + waiter.set_result(_ReleasingContextManager(self)) + break + + def acquire(self, timeout=None): + """Decrement the counter. Returns a Future. + + Block if the counter is zero and wait for a `.release`. The Future + raises `.TimeoutError` after the deadline. + """ + waiter = Future() + if self._value > 0: + self._value -= 1 + waiter.set_result(_ReleasingContextManager(self)) + else: + self._waiters.append(waiter) + if timeout: + def on_timeout(): + waiter.set_exception(gen.TimeoutError()) + self._garbage_collect() + io_loop = ioloop.IOLoop.current() + timeout_handle = io_loop.add_timeout(timeout, on_timeout) + waiter.add_done_callback( + lambda _: io_loop.remove_timeout(timeout_handle)) + return waiter + + def __enter__(self): + raise RuntimeError( + "Use Semaphore like 'with (yield semaphore.acquire())', not like" + " 'with semaphore'") + + __exit__ = __enter__ + + +class BoundedSemaphore(Semaphore): + """A semaphore that prevents release() being called too many times. + + If `.release` would increment the semaphore's value past the initial + value, it raises `ValueError`. Semaphores are mostly used to guard + resources with limited capacity, so a semaphore released too many times + is a sign of a bug. 
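A release without a matching acquire trips this check immediately::

    sem = locks.BoundedSemaphore(1)
    sem.release()  # ValueError: the counter would exceed its initial value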
+ """ + def __init__(self, value=1): + super(BoundedSemaphore, self).__init__(value=value) + self._initial_value = value + + def release(self): + """Increment the counter and wake one waiter.""" + if self._value >= self._initial_value: + raise ValueError("Semaphore released too many times") + super(BoundedSemaphore, self).release() + + +class Lock(object): + """A lock for coroutines. + + A Lock begins unlocked, and `acquire` locks it immediately. While it is + locked, a coroutine that yields `acquire` waits until another coroutine + calls `release`. + + Releasing an unlocked lock raises `RuntimeError`. + + `acquire` supports the context manager protocol: + + >>> from tornado import gen, locks + >>> lock = locks.Lock() + >>> + >>> @gen.coroutine + ... def f(): + ... with (yield lock.acquire()): + ... # Do something holding the lock. + ... pass + ... + ... # Now the lock is released. + """ + def __init__(self): + self._block = BoundedSemaphore(value=1) + + def __repr__(self): + return "<%s _block=%s>" % ( + self.__class__.__name__, + self._block) + + def acquire(self, timeout=None): + """Attempt to lock. Returns a Future. + + Returns a Future, which raises `tornado.gen.TimeoutError` after a + timeout. + """ + return self._block.acquire(timeout) + + def release(self): + """Unlock. + + The first coroutine in line waiting for `acquire` gets the lock. + + If not locked, raise a `RuntimeError`. + """ + try: + self._block.release() + except ValueError: + raise RuntimeError('release unlocked lock') + + def __enter__(self): + raise RuntimeError( + "Use Lock like 'with (yield lock)', not like 'with lock'") + + __exit__ = __enter__ diff --git a/tornado/log.py b/tornado/log.py index 374071d419ddde5b590f0e865aefa5638ab2733e..c68dec46bad5f1331fb121d6f3c0d43047e50484 100644 --- a/tornado/log.py +++ b/tornado/log.py @@ -206,6 +206,14 @@ def enable_pretty_logging(options=None, logger=None): def define_logging_options(options=None): + """Add logging-related flags to ``options``. + + These options are present automatically on the default options instance; + this method is only necessary if you have created your own `.OptionParser`. + + .. versionadded:: 4.2 + This function existed in prior versions but was broken and undocumented until 4.2. + """ if options is None: # late import to prevent cycle from tornado.options import options @@ -227,4 +235,4 @@ def define_logging_options(options=None): options.define("log_file_num_backups", type=int, default=10, help="number of log files to keep") - options.add_parse_callback(enable_pretty_logging) + options.add_parse_callback(lambda: enable_pretty_logging(options)) diff --git a/tornado/netutil.py b/tornado/netutil.py index 17e9580405d664b7e9d4cc0b13ac30c405fa63f4..9aa292c41729d45c9dc7e046df5a3ba892d9220f 100644 --- a/tornado/netutil.py +++ b/tornado/netutil.py @@ -35,6 +35,15 @@ except ImportError: # ssl is not available on Google App Engine ssl = None +try: + import certifi +except ImportError: + # certifi is optional as long as we have ssl.create_default_context. 
+ if ssl is None or hasattr(ssl, 'create_default_context'): + certifi = None + else: + raise + try: xrange # py2 except NameError: @@ -50,6 +59,38 @@ else: ssl_match_hostname = backports.ssl_match_hostname.match_hostname SSLCertificateError = backports.ssl_match_hostname.CertificateError +if hasattr(ssl, 'SSLContext'): + if hasattr(ssl, 'create_default_context'): + # Python 2.7.9+, 3.4+ + # Note that the naming of ssl.Purpose is confusing; the purpose + # of a context is to authenticate the opposite side of the connection. + _client_ssl_defaults = ssl.create_default_context( + ssl.Purpose.SERVER_AUTH) + _server_ssl_defaults = ssl.create_default_context( + ssl.Purpose.CLIENT_AUTH) + else: + # Python 3.2-3.3 + _client_ssl_defaults = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + _client_ssl_defaults.verify_mode = ssl.CERT_REQUIRED + _client_ssl_defaults.load_verify_locations(certifi.where()) + _server_ssl_defaults = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + if hasattr(ssl, 'OP_NO_COMPRESSION'): + # Disable TLS compression to avoid CRIME and related attacks. + # This constant wasn't added until python 3.3. + _client_ssl_defaults.options |= ssl.OP_NO_COMPRESSION + _server_ssl_defaults.options |= ssl.OP_NO_COMPRESSION + +elif ssl: + # Python 2.6-2.7.8 + _client_ssl_defaults = dict(cert_reqs=ssl.CERT_REQUIRED, + ca_certs=certifi.where()) + _server_ssl_defaults = {} +else: + # Google App Engine + _client_ssl_defaults = dict(cert_reqs=None, + ca_certs=None) + _server_ssl_defaults = {} + # ThreadedResolver runs getaddrinfo on a thread. If the hostname is unicode, # getaddrinfo attempts to import encodings.idna. If this is done at # module-import time, the import lock is already held by the main thread, @@ -68,6 +109,7 @@ if hasattr(errno, "WSAEWOULDBLOCK"): # Default backlog used when calling sock.listen() _DEFAULT_BACKLOG = 128 + def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=_DEFAULT_BACKLOG, flags=None): """Creates listening sockets bound to the given port and address. @@ -419,7 +461,7 @@ def ssl_options_to_context(ssl_options): `~ssl.SSLContext` object. The ``ssl_options`` dictionary contains keywords to be passed to - `ssl.wrap_socket`. In Python 3.2+, `ssl.SSLContext` objects can + `ssl.wrap_socket`. In Python 2.7.9+, `ssl.SSLContext` objects can be used instead. This function converts the dict form to its `~ssl.SSLContext` equivalent, and may be used when a component which accepts both forms needs to upgrade to the `~ssl.SSLContext` version @@ -450,11 +492,11 @@ def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs): """Returns an ``ssl.SSLSocket`` wrapping the given socket. - ``ssl_options`` may be either a dictionary (as accepted by - `ssl_options_to_context`) or an `ssl.SSLContext` object. - Additional keyword arguments are passed to ``wrap_socket`` - (either the `~ssl.SSLContext` method or the `ssl` module function - as appropriate). + ``ssl_options`` may be either an `ssl.SSLContext` object or a + dictionary (as accepted by `ssl_options_to_context`). Additional + keyword arguments are passed to ``wrap_socket`` (either the + `~ssl.SSLContext` method or the `ssl` module function as + appropriate).
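For instance, for a verifying client these two calls are roughly equivalent (``sock`` is assumed to be an already-created socket)::

    ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
    ssl_wrap_socket(sock, ctx, server_hostname="example.com")

    ssl_wrap_socket(sock,
                    dict(cert_reqs=ssl.CERT_REQUIRED,
                         ca_certs=certifi.where()),
                    server_hostname="example.com")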
""" context = ssl_options_to_context(ssl_options) if hasattr(ssl, 'SSLContext') and isinstance(context, ssl.SSLContext): diff --git a/tornado/options.py b/tornado/options.py index c855407c295ea475f864eae9497402192aa24a14..89a9e4326539e5d4d330ab1aea2b74d138f97cf3 100644 --- a/tornado/options.py +++ b/tornado/options.py @@ -256,7 +256,7 @@ class OptionParser(object): arg = args[i].lstrip("-") name, equals, value = arg.partition("=") name = name.replace('-', '_') - if not name in self._options: + if name not in self._options: self.print_help() raise Error('Unrecognized command line option: %r' % name) option = self._options[name] diff --git a/tornado/platform/asyncio.py b/tornado/platform/asyncio.py index bc6851750ac119e8e22de79e9a5ffa3bc000dde5..8f3dbff640008bf0871b8c6e6462fe14dd08740a 100644 --- a/tornado/platform/asyncio.py +++ b/tornado/platform/asyncio.py @@ -29,8 +29,10 @@ except ImportError as e: # Re-raise the original asyncio error, not the trollius one. raise e + class BaseAsyncIOLoop(IOLoop): - def initialize(self, asyncio_loop, close_loop=False): + def initialize(self, asyncio_loop, close_loop=False, **kwargs): + super(BaseAsyncIOLoop, self).initialize(**kwargs) self.asyncio_loop = asyncio_loop self.close_loop = close_loop self.asyncio_loop.call_soon(self.make_current) @@ -131,15 +133,16 @@ class BaseAsyncIOLoop(IOLoop): class AsyncIOMainLoop(BaseAsyncIOLoop): - def initialize(self): + def initialize(self, **kwargs): super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(), - close_loop=False) + close_loop=False, **kwargs) class AsyncIOLoop(BaseAsyncIOLoop): - def initialize(self): + def initialize(self, **kwargs): super(AsyncIOLoop, self).initialize(asyncio.new_event_loop(), - close_loop=True) + close_loop=True, **kwargs) + def to_tornado_future(asyncio_future): """Convert an ``asyncio.Future`` to a `tornado.concurrent.Future`.""" @@ -147,6 +150,7 @@ def to_tornado_future(asyncio_future): tornado.concurrent.chain_future(asyncio_future, tf) return tf + def to_asyncio_future(tornado_future): """Convert a `tornado.concurrent.Future` to an ``asyncio.Future``.""" af = asyncio.Future() diff --git a/tornado/platform/auto.py b/tornado/platform/auto.py index ddfe06b4a5e2c33ea2b12f90ce3196c53bf4b09a..fc40c9d973321a4c8e82f89adb6da5ed445b12d8 100644 --- a/tornado/platform/auto.py +++ b/tornado/platform/auto.py @@ -27,13 +27,14 @@ from __future__ import absolute_import, division, print_function, with_statement import os -if os.name == 'nt': - from tornado.platform.common import Waker - from tornado.platform.windows import set_close_exec -elif 'APPENGINE_RUNTIME' in os.environ: +if 'APPENGINE_RUNTIME' in os.environ: from tornado.platform.common import Waker + def set_close_exec(fd): pass +elif os.name == 'nt': + from tornado.platform.common import Waker + from tornado.platform.windows import set_close_exec else: from tornado.platform.posix import set_close_exec, Waker @@ -41,9 +42,13 @@ try: # monotime monkey-patches the time module to have a monotonic function # in versions of python before 3.3. 
import monotime + # Silence pyflakes warning about this unused import + monotime except ImportError: pass try: from time import monotonic as monotonic_time except ImportError: monotonic_time = None + +__all__ = ['Waker', 'set_close_exec', 'monotonic_time'] diff --git a/tornado/platform/select.py b/tornado/platform/select.py index 1e1265547ce7f396f070e1125ec171dcb51c1b80..db52ef91063bf813f8ed39f5fe2ef4d803429e72 100644 --- a/tornado/platform/select.py +++ b/tornado/platform/select.py @@ -47,7 +47,7 @@ class _Select(object): # Closed connections are reported as errors by epoll and kqueue, # but as zero-byte reads by select, so when errors are requested # we need to listen for both read and error. - #self.read_fds.add(fd) + # self.read_fds.add(fd) def modify(self, fd, events): self.unregister(fd) diff --git a/tornado/platform/twisted.py b/tornado/platform/twisted.py index 09b328366b6796dccaa69f7d40f59794fd814dd5..7b3c8ca5e39f26ae954222bbfb398c58b1a864a3 100644 --- a/tornado/platform/twisted.py +++ b/tornado/platform/twisted.py @@ -35,7 +35,7 @@ of the application:: tornado.platform.twisted.install() from twisted.internet import reactor -When the app is ready to start, call `IOLoop.instance().start()` +When the app is ready to start, call `IOLoop.current().start()` instead of `reactor.run()`. It is also possible to create a non-global reactor by calling @@ -416,7 +416,8 @@ class TwistedIOLoop(tornado.ioloop.IOLoop): because the ``SIGCHLD`` handlers used by Tornado and Twisted conflict with each other. """ - def initialize(self, reactor=None): + def initialize(self, reactor=None, **kwargs): + super(TwistedIOLoop, self).initialize(**kwargs) if reactor is None: import twisted.internet.reactor reactor = twisted.internet.reactor @@ -572,6 +573,7 @@ if hasattr(gen.convert_yielded, 'register'): @gen.convert_yielded.register(Deferred) def _(d): f = Future() + def errback(failure): try: failure.raiseException() diff --git a/tornado/process.py b/tornado/process.py index 3790ca0a55f99c591f02aabe18a9acdd9a2f2afc..f580e19253340ac892f502043ad7bf4cacdbfd0a 100644 --- a/tornado/process.py +++ b/tornado/process.py @@ -29,6 +29,7 @@ import time from binascii import hexlify +from tornado.concurrent import Future from tornado import ioloop from tornado.iostream import PipeIOStream from tornado.log import gen_log @@ -48,6 +49,10 @@ except NameError: long = int # py3 +# Re-export this exception for convenience. +CalledProcessError = subprocess.CalledProcessError + + def cpu_count(): """Returns the number of processors on this machine.""" if multiprocessing is None: @@ -258,6 +263,33 @@ class Subprocess(object): Subprocess._waiting[self.pid] = self Subprocess._try_cleanup_process(self.pid) + def wait_for_exit(self, raise_error=True): + """Returns a `.Future` which resolves when the process exits. + + Usage:: + + ret = yield proc.wait_for_exit() + + This is a coroutine-friendly alternative to `set_exit_callback` + (and a replacement for the blocking `subprocess.Popen.wait`). + + By default, raises `subprocess.CalledProcessError` if the process + has a non-zero exit status. Use ``wait_for_exit(raise_error=False)`` + to suppress this behavior and return the exit status without raising. + + .. versionadded:: 4.2 + """ + future = Future() + + def callback(ret): + if ret != 0 and raise_error: + # Unfortunately we don't have the original args any more. 
+ future.set_exception(CalledProcessError(ret, None)) + else: + future.set_result(ret) + self.set_exit_callback(callback) + return future + @classmethod def initialize(cls, io_loop=None): """Initializes the ``SIGCHLD`` handler. diff --git a/tornado/queues.py b/tornado/queues.py new file mode 100644 index 0000000000000000000000000000000000000000..55ab4834ed244d722305f4e6d4b84c7629e7da6d --- /dev/null +++ b/tornado/queues.py @@ -0,0 +1,321 @@ +# Copyright 2015 The Tornado Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, with_statement + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty'] + +import collections +import heapq + +from tornado import gen, ioloop +from tornado.concurrent import Future +from tornado.locks import Event + + +class QueueEmpty(Exception): + """Raised by `.Queue.get_nowait` when the queue has no items.""" + pass + + +class QueueFull(Exception): + """Raised by `.Queue.put_nowait` when a queue is at its maximum size.""" + pass + + +def _set_timeout(future, timeout): + if timeout: + def on_timeout(): + future.set_exception(gen.TimeoutError()) + io_loop = ioloop.IOLoop.current() + timeout_handle = io_loop.add_timeout(timeout, on_timeout) + future.add_done_callback( + lambda _: io_loop.remove_timeout(timeout_handle)) + + +class Queue(object): + """Coordinate producer and consumer coroutines. + + If maxsize is 0 (the default) the queue size is unbounded. + + .. testcode:: + + q = queues.Queue(maxsize=2) + + @gen.coroutine + def consumer(): + while True: + item = yield q.get() + try: + print('Doing work on %s' % item) + yield gen.sleep(0.01) + finally: + q.task_done() + + @gen.coroutine + def producer(): + for item in range(5): + yield q.put(item) + print('Put %s' % item) + + @gen.coroutine + def main(): + consumer() # Start consumer. + yield producer() # Wait for producer to put all tasks. + yield q.join() # Wait for consumer to finish all tasks. + print('Done') + + io_loop.run_sync(main) + + .. testoutput:: + + Put 0 + Put 1 + Put 2 + Doing work on 0 + Doing work on 1 + Put 3 + Doing work on 2 + Put 4 + Doing work on 3 + Doing work on 4 + Done + """ + def __init__(self, maxsize=0): + if maxsize is None: + raise TypeError("maxsize can't be None") + + if maxsize < 0: + raise ValueError("maxsize can't be negative") + + self._maxsize = maxsize + self._init() + self._getters = collections.deque([]) # Futures. + self._putters = collections.deque([]) # Pairs of (item, Future). 
+ self._unfinished_tasks = 0 + self._finished = Event() + self._finished.set() + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + def empty(self): + return not self._queue + + def full(self): + if self.maxsize == 0: + return False + else: + return self.qsize() >= self.maxsize + + def put(self, item, timeout=None): + """Put an item into the queue, perhaps waiting until there is room. + + Returns a Future, which raises `tornado.gen.TimeoutError` after a + timeout. + """ + try: + self.put_nowait(item) + except QueueFull: + future = Future() + self._putters.append((item, future)) + _set_timeout(future, timeout) + return future + else: + return gen._null_future + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise `QueueFull`. + """ + self._consume_expired() + if self._getters: + assert self.empty(), "queue non-empty, why are getters waiting?" + getter = self._getters.popleft() + self.__put_internal(item) + getter.set_result(self._get()) + elif self.full(): + raise QueueFull + else: + self.__put_internal(item) + + def get(self, timeout=None): + """Remove and return an item from the queue. + + Returns a Future which resolves once an item is available, or raises + `tornado.gen.TimeoutError` after a timeout. + """ + future = Future() + try: + future.set_result(self.get_nowait()) + except QueueEmpty: + self._getters.append(future) + _set_timeout(future, timeout) + return future + + def get_nowait(self): + """Remove and return an item from the queue without blocking. + + Return an item if one is immediately available, else raise + `QueueEmpty`. + """ + self._consume_expired() + if self._putters: + assert self.full(), "queue not full, why are putters waiting?" + item, putter = self._putters.popleft() + self.__put_internal(item) + putter.set_result(None) + return self._get() + elif self.qsize(): + return self._get() + else: + raise QueueEmpty + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each `.get` used to fetch a task, a + subsequent call to `.task_done` tells the queue that the processing + on the task is complete. + + If a `.join` is blocking, it resumes when all items have been + processed; that is, when every `.put` is matched by a `.task_done`. + + Raises `ValueError` if called more times than `.put`. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + def join(self, timeout=None): + """Block until all items in the queue are processed. + + Returns a Future, which raises `tornado.gen.TimeoutError` after a + timeout. + """ + return self._finished.wait(timeout) + + # These three are overridable in subclasses. + def _init(self): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + # End of the overridable methods. + + def __put_internal(self, item): + self._unfinished_tasks += 1 + self._finished.clear() + self._put(item) + + def _consume_expired(self): + # Remove timed-out waiters. 
+ while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + while self._getters and self._getters[0].done(): + self._getters.popleft() + + def __repr__(self): + return '<%s at %s %s>' % ( + type(self).__name__, hex(id(self)), self._format()) + + def __str__(self): + return '<%s %s>' % (type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize=%r' % (self.maxsize, ) + if getattr(self, '_queue', None): + result += ' queue=%r' % self._queue + if self._getters: + result += ' getters[%s]' % len(self._getters) + if self._putters: + result += ' putters[%s]' % len(self._putters) + if self._unfinished_tasks: + result += ' tasks=%s' % self._unfinished_tasks + return result + + +class PriorityQueue(Queue): + """A `.Queue` that retrieves entries in priority order, lowest first. + + Entries are typically tuples like ``(priority number, data)``. + + .. testcode:: + + q = queues.PriorityQueue() + q.put((1, 'medium-priority item')) + q.put((0, 'high-priority item')) + q.put((10, 'low-priority item')) + + print(q.get_nowait()) + print(q.get_nowait()) + print(q.get_nowait()) + + .. testoutput:: + + (0, 'high-priority item') + (1, 'medium-priority item') + (10, 'low-priority item') + """ + def _init(self): + self._queue = [] + + def _put(self, item): + heapq.heappush(self._queue, item) + + def _get(self): + return heapq.heappop(self._queue) + + +class LifoQueue(Queue): + """A `.Queue` that retrieves the most recently put items first. + + .. testcode:: + + q = queues.LifoQueue() + q.put(3) + q.put(2) + q.put(1) + + print(q.get_nowait()) + print(q.get_nowait()) + print(q.get_nowait()) + + .. testoutput:: + + 1 + 2 + 3 + """ + def _init(self): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index 31d076e2d114d0840dcf586410aee11eebcaf43f..cf58e16263bc98fd1b9802845e295f15c17b440e 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -7,7 +7,7 @@ from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _ from tornado import httputil from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters from tornado.iostream import StreamClosedError -from tornado.netutil import Resolver, OverrideResolver +from tornado.netutil import Resolver, OverrideResolver, _client_ssl_defaults from tornado.log import gen_log from tornado import stack_context from tornado.tcpclient import TCPClient @@ -50,9 +50,6 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): """Non-blocking HTTP client with no external dependencies. This class implements an HTTP 1.1 client on top of Tornado's IOStreams. - It does not currently implement all applicable parts of the HTTP - specification, but it does enough to work with major web service APIs. - Some features found in the curl-based AsyncHTTPClient are not yet supported. In particular, proxies are not supported, connections are not reused, and callers cannot select the network interface to be @@ -60,25 +57,39 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): """ def initialize(self, io_loop, max_clients=10, hostname_mapping=None, max_buffer_size=104857600, - resolver=None, defaults=None, max_header_size=None): + resolver=None, defaults=None, max_header_size=None, + max_body_size=None): """Creates a AsyncHTTPClient. Only a single AsyncHTTPClient instance exists per IOLoop in order to provide limitations on the number of pending connections. 
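Because of this sharing, ``configure`` is the reliable way to change these settings; for example (the ``max_clients`` value here is arbitrary)::

    AsyncHTTPClient.configure(None, max_clients=50)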
- force_instance=True may be used to suppress this behavior. + ``force_instance=True`` may be used to suppress this behavior. + + Note that because of this implicit reuse, unless ``force_instance`` + is used, only the first call to the constructor actually uses + its arguments. It is recommended to use the ``configure`` method + instead of the constructor to ensure that arguments take effect. - max_clients is the number of concurrent requests that can be - in progress. Note that this arguments are only used when the - client is first created, and will be ignored when an existing - client is reused. + ``max_clients`` is the number of concurrent requests that can be + in progress; when this limit is reached additional requests will be + queued. Note that time spent waiting in this queue still counts + against the ``request_timeout``. - hostname_mapping is a dictionary mapping hostnames to IP addresses. + ``hostname_mapping`` is a dictionary mapping hostnames to IP addresses. It can be used to make local DNS changes when modifying system-wide - settings like /etc/hosts is not possible or desirable (e.g. in + settings like ``/etc/hosts`` is not possible or desirable (e.g. in unittests). - max_buffer_size is the number of bytes that can be read by IOStream. It - defaults to 100mb. + ``max_buffer_size`` (default 100MB) is the number of bytes + that can be read into memory at once. ``max_body_size`` + (defaults to ``max_buffer_size``) is the largest response body + that the client will accept. Without a + ``streaming_callback``, the smaller of these two limits + applies; with a ``streaming_callback`` only ``max_body_size`` + does. + + .. versionchanged:: 4.2 + Added the ``max_body_size`` argument. """ super(SimpleAsyncHTTPClient, self).initialize(io_loop, defaults=defaults) @@ -88,6 +99,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): self.waiting = {} self.max_buffer_size = max_buffer_size self.max_header_size = max_header_size + self.max_body_size = max_body_size # TCPClient could create a Resolver for us, but we have to do it # ourselves to support hostname_mapping. 
if resolver: @@ -135,10 +147,14 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): release_callback = functools.partial(self._release_fetch, key) self._handle_request(request, release_callback, callback) + def _connection_class(self): + return _HTTPConnection + def _handle_request(self, request, release_callback, final_callback): - _HTTPConnection(self.io_loop, self, request, release_callback, - final_callback, self.max_buffer_size, self.tcp_client, - self.max_header_size) + self._connection_class()( + self.io_loop, self, request, release_callback, + final_callback, self.max_buffer_size, self.tcp_client, + self.max_header_size, self.max_body_size) def _release_fetch(self, key): del self.active[key] @@ -166,7 +182,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): def __init__(self, io_loop, client, request, release_callback, final_callback, max_buffer_size, tcp_client, - max_header_size): + max_header_size, max_body_size): self.start_time = io_loop.time() self.io_loop = io_loop self.client = client @@ -176,6 +192,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): self.max_buffer_size = max_buffer_size self.tcp_client = tcp_client self.max_header_size = max_header_size + self.max_body_size = max_body_size self.code = None self.headers = None self.chunks = [] @@ -220,12 +237,24 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): def _get_ssl_options(self, scheme): if scheme == "https": + if self.request.ssl_options is not None: + return self.request.ssl_options + # If we are using the defaults, don't construct a + # new SSLContext. + if (self.request.validate_cert and + self.request.ca_certs is None and + self.request.client_cert is None and + self.request.client_key is None): + return _client_ssl_defaults ssl_options = {} if self.request.validate_cert: ssl_options["cert_reqs"] = ssl.CERT_REQUIRED if self.request.ca_certs is not None: ssl_options["ca_certs"] = self.request.ca_certs - else: + elif not hasattr(ssl, 'create_default_context'): + # When create_default_context is present, + # we can omit the "ca_certs" parameter entirely, + # which avoids the dependency on "certifi" for py34. ssl_options["ca_certs"] = _default_ca_certs() if self.request.client_key is not None: ssl_options["keyfile"] = self.request.client_key @@ -317,9 +346,9 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): body_present = (self.request.body is not None or self.request.body_producer is not None) if ((body_expected and not body_present) or - (body_present and not body_expected)): + (body_present and not body_expected)): raise ValueError( - 'Body must %sbe None for method %s (unelss ' + 'Body must %sbe None for method %s (unless ' 'allow_nonstandard_methods is true)' % ('not ' if body_expected else '', self.request.method)) if self.request.expect_100_continue: @@ -336,14 +365,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): self.request.headers["Accept-Encoding"] = "gzip" req_path = ((self.parsed.path or '/') + (('?' 
+ self.parsed.query) if self.parsed.query else '')) - self.stream.set_nodelay(True) - self.connection = HTTP1Connection( - self.stream, True, - HTTP1ConnectionParameters( - no_keep_alive=True, - max_header_size=self.max_header_size, - decompress=self.request.decompress_response), - self._sockaddr) + self.connection = self._create_connection(stream) start_line = httputil.RequestStartLine(self.request.method, req_path, '') self.connection.write_headers(start_line, self.request.headers) @@ -352,10 +374,21 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): else: self._write_body(True) + def _create_connection(self, stream): + stream.set_nodelay(True) + connection = HTTP1Connection( + stream, True, + HTTP1ConnectionParameters( + no_keep_alive=True, + max_header_size=self.max_header_size, + max_body_size=self.max_body_size, + decompress=self.request.decompress_response), + self._sockaddr) + return connection + def _write_body(self, start_read): if self.request.body is not None: self.connection.write(self.request.body) - self.connection.finish() elif self.request.body_producer is not None: fut = self.request.body_producer(self.connection.write) if is_future(fut): @@ -366,7 +399,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): self._read_response() self.io_loop.add_future(fut, on_body_written) return - self.connection.finish() + self.connection.finish() if start_read: self._read_response() diff --git a/tornado/tcpserver.py b/tornado/tcpserver.py index a02b36ffffda457172f7c568decea50c1829ea1a..c9d148a80e90d8424da38360dbee0455948411f2 100644 --- a/tornado/tcpserver.py +++ b/tornado/tcpserver.py @@ -41,14 +41,15 @@ class TCPServer(object): To use `TCPServer`, define a subclass which overrides the `handle_stream` method. - To make this server serve SSL traffic, send the ssl_options dictionary - argument with the arguments required for the `ssl.wrap_socket` method, - including "certfile" and "keyfile":: + To make this server serve SSL traffic, send the ``ssl_options`` keyword + argument with an `ssl.SSLContext` object. For compatibility with older + versions of Python ``ssl_options`` may also be a dictionary of keyword + arguments for the `ssl.wrap_socket` method.:: - TCPServer(ssl_options={ - "certfile": os.path.join(data_dir, "mydomain.crt"), - "keyfile": os.path.join(data_dir, "mydomain.key"), - }) + ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"), + os.path.join(data_dir, "mydomain.key")) + TCPServer(ssl_options=ssl_ctx) `TCPServer` initialization follows one of three patterns: @@ -56,14 +57,14 @@ class TCPServer(object): server = TCPServer() server.listen(8888) - IOLoop.instance().start() + IOLoop.current().start() 2. `bind`/`start`: simple multi-process:: server = TCPServer() server.bind(8888) server.start(0) # Forks multiple sub-processes - IOLoop.instance().start() + IOLoop.current().start() When using this interface, an `.IOLoop` must *not* be passed to the `TCPServer` constructor. 
`start` will always start @@ -75,7 +76,7 @@ class TCPServer(object): tornado.process.fork_processes(0) server = TCPServer() server.add_sockets(sockets) - IOLoop.instance().start() + IOLoop.current().start() The `add_sockets` interface is more complicated, but it can be used with `tornado.process.fork_processes` to give you more @@ -212,7 +213,20 @@ class TCPServer(object): sock.close() def handle_stream(self, stream, address): - """Override to handle a new `.IOStream` from an incoming connection.""" + """Override to handle a new `.IOStream` from an incoming connection. + + This method may be a coroutine; if so any exceptions it raises + asynchronously will be logged. Accepting of incoming connections + will not be blocked by this coroutine. + + If this `TCPServer` is configured for SSL, ``handle_stream`` + may be called before the SSL handshake has completed. Use + `.SSLIOStream.wait_for_handshake` if you need to verify the client's + certificate or use NPN/ALPN. + + .. versionchanged:: 4.2 + Added the option for this method to be a coroutine. + """ raise NotImplementedError() def _handle_connection(self, connection, address): @@ -252,6 +266,8 @@ class TCPServer(object): stream = IOStream(connection, io_loop=self.io_loop, max_buffer_size=self.max_buffer_size, read_chunk_size=self.read_chunk_size) - self.handle_stream(stream, address) + future = self.handle_stream(stream, address) + if future is not None: + self.io_loop.add_future(future, lambda f: f.result()) except Exception: app_log.error("Error in connection callback", exc_info=True) diff --git a/tornado/test/README b/tornado/test/README deleted file mode 100644 index 33edba9832c9296b43280b86eed09d0c43cc9abb..0000000000000000000000000000000000000000 --- a/tornado/test/README +++ /dev/null @@ -1,4 +0,0 @@ -Test coverage is almost non-existent, but it's a start. Be sure to -set PYTHONPATH appropriately (generally to the root directory of your -tornado checkout) when running tests to make sure you're getting the -version of the tornado package that you expect. 
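The coroutine form of ``handle_stream`` documented above admits a compact server. A minimal echo-server sketch, assuming the 4.2 semantics introduced in this patch::

    from tornado import gen
    from tornado.iostream import StreamClosedError
    from tornado.tcpserver import TCPServer

    class EchoServer(TCPServer):
        @gen.coroutine
        def handle_stream(self, stream, address):
            # Exceptions raised here are logged via the Future that
            # _handle_connection now hands to the IOLoop.
            while True:
                try:
                    data = yield stream.read_until(b"\n")
                    yield stream.write(data)
                except StreamClosedError:
                    break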
\ No newline at end of file diff --git a/tornado/test/asyncio_test.py b/tornado/test/asyncio_test.py index cb990748f6f45173b982e9d7c0dce4ccbd35fb95..1be0e54f35529da7a0a6671992f0b279d74cffdb 100644 --- a/tornado/test/asyncio_test.py +++ b/tornado/test/asyncio_test.py @@ -27,6 +27,7 @@ except ImportError: skipIfNoSingleDispatch = unittest.skipIf( gen.singledispatch is None, "singledispatch module not present") + @unittest.skipIf(asyncio is None, "asyncio module not present") class AsyncIOLoopTest(AsyncTestCase): def get_new_ioloop(self): diff --git a/tornado/test/auth_test.py b/tornado/test/auth_test.py index 254e1ae13c6862c319f1d2b2f952c309111b2765..541ecf16f3888993d2a54f164a85fd5450bd71a1 100644 --- a/tornado/test/auth_test.py +++ b/tornado/test/auth_test.py @@ -5,7 +5,7 @@ from __future__ import absolute_import, division, print_function, with_statement -from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, GoogleMixin, AuthError +from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, AuthError from tornado.concurrent import Future from tornado.escape import json_decode from tornado import gen @@ -238,28 +238,6 @@ class TwitterServerVerifyCredentialsHandler(RequestHandler): self.write(dict(screen_name='foo', name='Foo')) -class GoogleOpenIdClientLoginHandler(RequestHandler, GoogleMixin): - def initialize(self, test): - self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate') - - @asynchronous - def get(self): - if self.get_argument("openid.mode", None): - self.get_authenticated_user(self.on_user) - return - res = self.authenticate_redirect() - assert isinstance(res, Future) - assert res.done() - - def on_user(self, user): - if user is None: - raise Exception("user is None") - self.finish(user) - - def get_auth_http_client(self): - return self.settings['http_client'] - - class AuthTest(AsyncHTTPTestCase): def get_app(self): return Application( @@ -286,7 +264,6 @@ class AuthTest(AsyncHTTPTestCase): ('/twitter/client/login_gen_coroutine', TwitterClientLoginGenCoroutineHandler, dict(test=self)), ('/twitter/client/show_user', TwitterClientShowUserHandler, dict(test=self)), ('/twitter/client/show_user_future', TwitterClientShowUserFutureHandler, dict(test=self)), - ('/google/client/openid_login', GoogleOpenIdClientLoginHandler, dict(test=self)), # simulated servers ('/openid/server/authenticate', OpenIdServerAuthenticateHandler), @@ -436,16 +413,3 @@ class AuthTest(AsyncHTTPTestCase): response = self.fetch('/twitter/client/show_user_future?name=error') self.assertEqual(response.code, 500) self.assertIn(b'Error response HTTP 500', response.body) - - def test_google_redirect(self): - # same as test_openid_redirect - response = self.fetch('/google/client/openid_login', follow_redirects=False) - self.assertEqual(response.code, 302) - self.assertTrue( - '/openid/server/authenticate?' 
in response.headers['Location']) - - def test_google_get_user(self): - response = self.fetch('/google/client/openid_login?openid.mode=blah&openid.ns.ax=http://openid.net/srv/ax/1.0&openid.ax.type.email=http://axschema.org/contact/email&openid.ax.value.email=foo@example.com', follow_redirects=False) - response.rethrow() - parsed = json_decode(response.body) - self.assertEqual(parsed["email"], "foo@example.com") diff --git a/tornado/test/concurrent_test.py b/tornado/test/concurrent_test.py index 5e93ad6a42ba5fd55a9f100a3fbd448085088e89..bf90ad0ec92cd829f0084941258dc239ebbc54c5 100644 --- a/tornado/test/concurrent_test.py +++ b/tornado/test/concurrent_test.py @@ -21,13 +21,14 @@ import socket import sys import traceback -from tornado.concurrent import Future, return_future, ReturnValueIgnoredError +from tornado.concurrent import Future, return_future, ReturnValueIgnoredError, run_on_executor from tornado.escape import utf8, to_unicode from tornado import gen from tornado.iostream import IOStream from tornado import stack_context from tornado.tcpserver import TCPServer from tornado.testing import AsyncTestCase, LogTrapTestCase, bind_unused_port, gen_test +from tornado.test.util import unittest try: @@ -334,3 +335,81 @@ class DecoratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase): class GeneratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase): client_class = GeneratorCapClient + + +@unittest.skipIf(futures is None, "concurrent.futures module not present") +class RunOnExecutorTest(AsyncTestCase): + @gen_test + def test_no_calling(self): + class Object(object): + def __init__(self, io_loop): + self.io_loop = io_loop + self.executor = futures.thread.ThreadPoolExecutor(1) + + @run_on_executor + def f(self): + return 42 + + o = Object(io_loop=self.io_loop) + answer = yield o.f() + self.assertEqual(answer, 42) + + @gen_test + def test_call_with_no_args(self): + class Object(object): + def __init__(self, io_loop): + self.io_loop = io_loop + self.executor = futures.thread.ThreadPoolExecutor(1) + + @run_on_executor() + def f(self): + return 42 + + o = Object(io_loop=self.io_loop) + answer = yield o.f() + self.assertEqual(answer, 42) + + @gen_test + def test_call_with_io_loop(self): + class Object(object): + def __init__(self, io_loop): + self._io_loop = io_loop + self.executor = futures.thread.ThreadPoolExecutor(1) + + @run_on_executor(io_loop='_io_loop') + def f(self): + return 42 + + o = Object(io_loop=self.io_loop) + answer = yield o.f() + self.assertEqual(answer, 42) + + @gen_test + def test_call_with_executor(self): + class Object(object): + def __init__(self, io_loop): + self.io_loop = io_loop + self.__executor = futures.thread.ThreadPoolExecutor(1) + + @run_on_executor(executor='_Object__executor') + def f(self): + return 42 + + o = Object(io_loop=self.io_loop) + answer = yield o.f() + self.assertEqual(answer, 42) + + @gen_test + def test_call_with_both(self): + class Object(object): + def __init__(self, io_loop): + self._io_loop = io_loop + self.__executor = futures.thread.ThreadPoolExecutor(1) + + @run_on_executor(io_loop='_io_loop', executor='_Object__executor') + def f(self): + return 42 + + o = Object(io_loop=self.io_loop) + answer = yield o.f() + self.assertEqual(answer, 42) diff --git a/tornado/test/curl_httpclient_test.py b/tornado/test/curl_httpclient_test.py index 8d7065dfe34287c62f7e9e3e34d7fb07bc61854d..3ac21f4d7260be0f4927695e61580a7c6c628118 100644 --- a/tornado/test/curl_httpclient_test.py +++ b/tornado/test/curl_httpclient_test.py @@ -8,7 +8,7 @@ from 
tornado.stack_context import ExceptionStackContext from tornado.testing import AsyncHTTPTestCase from tornado.test import httpclient_test from tornado.test.util import unittest -from tornado.web import Application, RequestHandler, URLSpec +from tornado.web import Application, RequestHandler try: diff --git a/tornado/test/escape_test.py b/tornado/test/escape_test.py index f640428881522e869875dcd6d079a1bf01b3e5a9..65765b68aa31fd90205d38fe22dcbf3b081ca381 100644 --- a/tornado/test/escape_test.py +++ b/tornado/test/escape_test.py @@ -154,6 +154,19 @@ class EscapeTestCase(unittest.TestCase): self.assertEqual(utf8(xhtml_escape(unescaped)), utf8(escaped)) self.assertEqual(utf8(unescaped), utf8(xhtml_unescape(escaped))) + def test_xhtml_unescape_numeric(self): + tests = [ + ('foo bar', 'foo bar'), + ('foo bar', 'foo bar'), + ('foo bar', 'foo bar'), + ('foo઼bar', u('foo\u0abcbar')), + ('foo&#xyz;bar', 'foo&#xyz;bar'), # invalid encoding + ('foo&#;bar', 'foo&#;bar'), # invalid encoding + ('foo&#x;bar', 'foo&#x;bar'), # invalid encoding + ] + for escaped, unescaped in tests: + self.assertEqual(unescaped, xhtml_unescape(escaped)) + def test_url_escape_unicode(self): tests = [ # byte strings are passed through as-is @@ -217,9 +230,8 @@ class EscapeTestCase(unittest.TestCase): self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9") def test_squeeze(self): - self.assertEqual(squeeze(u('sequences of whitespace chars')) - , u('sequences of whitespace chars')) - + self.assertEqual(squeeze(u('sequences of whitespace chars')), u('sequences of whitespace chars')) + def test_recursive_unicode(self): tests = { 'dict': {b"foo": b"bar"}, diff --git a/tornado/test/gen_test.py b/tornado/test/gen_test.py index 5d646d15007a0b100d5f9fa0c00623ba48690372..fdaa0ec804dd56c1efc3a9c46837bf453bb61ea4 100644 --- a/tornado/test/gen_test.py +++ b/tornado/test/gen_test.py @@ -62,6 +62,11 @@ class GenEngineTest(AsyncTestCase): def async_future(self, result, callback): self.io_loop.add_callback(callback, result) + @gen.coroutine + def async_exception(self, e): + yield gen.moment + raise e + def test_no_yield(self): @gen.engine def f(): @@ -385,11 +390,56 @@ class GenEngineTest(AsyncTestCase): results = yield [self.async_future(1), self.async_future(2)] self.assertEqual(results, [1, 2]) + @gen_test + def test_multi_future_duplicate(self): + f = self.async_future(2) + results = yield [self.async_future(1), f, self.async_future(3), f] + self.assertEqual(results, [1, 2, 3, 2]) + @gen_test def test_multi_dict_future(self): results = yield dict(foo=self.async_future(1), bar=self.async_future(2)) self.assertEqual(results, dict(foo=1, bar=2)) + @gen_test + def test_multi_exceptions(self): + with ExpectLog(app_log, "Multiple exceptions in yield list"): + with self.assertRaises(RuntimeError) as cm: + yield gen.Multi([self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2"))]) + self.assertEqual(str(cm.exception), "error 1") + + # With only one exception, no error is logged. + with self.assertRaises(RuntimeError): + yield gen.Multi([self.async_exception(RuntimeError("error 1")), + self.async_future(2)]) + + # Exception logging may be explicitly quieted. 
+ with self.assertRaises(RuntimeError): + yield gen.Multi([self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2"))], + quiet_exceptions=RuntimeError) + + @gen_test + def test_multi_future_exceptions(self): + with ExpectLog(app_log, "Multiple exceptions in yield list"): + with self.assertRaises(RuntimeError) as cm: + yield [self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2"))] + self.assertEqual(str(cm.exception), "error 1") + + # With only one exception, no error is logged. + with self.assertRaises(RuntimeError): + yield [self.async_exception(RuntimeError("error 1")), + self.async_future(2)] + + # Exception logging may be explicitly quieted. + with self.assertRaises(RuntimeError): + yield gen.multi_future( + [self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2"))], + quiet_exceptions=RuntimeError) + def test_arguments(self): @gen.engine def f(): @@ -816,6 +866,7 @@ class GenCoroutineTest(AsyncTestCase): @gen_test def test_moment(self): calls = [] + @gen.coroutine def f(name, yieldable): for i in range(5): @@ -843,6 +894,29 @@ class GenCoroutineTest(AsyncTestCase): yield gen.sleep(0.01) self.finished = True + @skipBefore33 + @gen_test + def test_py3_leak_exception_context(self): + class LeakedException(Exception): + pass + + @gen.coroutine + def inner(iteration): + raise LeakedException(iteration) + + try: + yield inner(1) + except LeakedException as e: + self.assertEqual(str(e), "1") + self.assertIsNone(e.__context__) + + try: + yield inner(2) + except LeakedException as e: + self.assertEqual(str(e), "2") + self.assertIsNone(e.__context__) + + self.finished = True class GenSequenceHandler(RequestHandler): @asynchronous @@ -1072,6 +1146,7 @@ class WithTimeoutTest(AsyncTestCase): yield gen.with_timeout(datetime.timedelta(seconds=3600), executor.submit(lambda: None)) + class WaitIteratorTest(AsyncTestCase): @gen_test def test_empty_iterator(self): @@ -1121,10 +1196,10 @@ class WaitIteratorTest(AsyncTestCase): while not dg.done(): dr = yield dg.next() if dg.current_index == "f1": - self.assertTrue(dg.current_future==f1 and dr==24, + self.assertTrue(dg.current_future == f1 and dr == 24, "WaitIterator dict status incorrect") elif dg.current_index == "f2": - self.assertTrue(dg.current_future==f2 and dr==42, + self.assertTrue(dg.current_future == f2 and dr == 42, "WaitIterator dict status incorrect") else: self.fail("got bad WaitIterator index {}".format( @@ -1145,7 +1220,7 @@ class WaitIteratorTest(AsyncTestCase): futures[3].set_result(84) if iteration < 8: - self.io_loop.add_callback(self.finish_coroutines, iteration+1, futures) + self.io_loop.add_callback(self.finish_coroutines, iteration + 1, futures) @gen_test def test_iterator(self): @@ -1174,5 +1249,15 @@ class WaitIteratorTest(AsyncTestCase): self.assertEqual(g.current_index, 3, 'wrong index') i += 1 + @gen_test + def test_no_ref(self): + # In this usage, there is no direct hard reference to the + # WaitIterator itself, only the Future it returns. Since + # WaitIterator uses weak references internally to improve GC + # performance, this used to cause problems. 
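The ``quiet_exceptions`` argument exercised by these tests is useful outside of tests as well. A hedged sketch of fetching several URLs in parallel (``client`` and ``urls`` are assumed to exist; the exception class to quiet depends on the failures you expect)::

    from tornado import gen

    @gen.coroutine
    def fetch_all(client, urls):
        # Only the first exception propagates; further expected
        # failures are suppressed instead of logged.
        responses = yield gen.multi_future(
            [client.fetch(url) for url in urls],
            quiet_exceptions=IOError)
        raise gen.Return(responses)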
+ yield gen.with_timeout(datetime.timedelta(seconds=0.1), + gen.WaitIterator(gen.sleep(0)).next()) + + if __name__ == '__main__': unittest.main() diff --git a/tornado/test/gettext_translations/extract_me.py b/tornado/test/gettext_translations/extract_me.py deleted file mode 100644 index 75406ecc77d70611d2a3986cdeb82c788063a7c2..0000000000000000000000000000000000000000 --- a/tornado/test/gettext_translations/extract_me.py +++ /dev/null @@ -1,11 +0,0 @@ -# Dummy source file to allow creation of the initial .po file in the -# same way as a real project. I'm not entirely sure about the real -# workflow here, but this seems to work. -# -# 1) xgettext --language=Python --keyword=_:1,2 -d tornado_test extract_me.py -o tornado_test.po -# 2) Edit tornado_test.po, setting CHARSET and setting msgstr -# 3) msgfmt tornado_test.po -o tornado_test.mo -# 4) Put the file in the proper location: $LANG/LC_MESSAGES - -from __future__ import absolute_import, division, print_function, with_statement -_("school") diff --git a/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.mo b/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.mo index 089f6c7ab79294f4441121c15ebaa10c3cea02f3..a97bf9c57460ecfc27761accf90d712ea5cebb44 100644 Binary files a/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.mo and b/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.mo differ diff --git a/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po b/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po index 732ee6da8e0e13e591a94651d7f70a18b676c753..88d72c8623a4275c85cb32e2ec35205b5b907176 100644 --- a/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po +++ b/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: PACKAGE VERSION\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2012-06-14 01:10-0700\n" +"POT-Creation-Date: 2015-01-27 11:05+0300\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME <EMAIL@ADDRESS>\n" "Language-Team: LANGUAGE <LL@li.org>\n" @@ -16,7 +16,32 @@ msgstr "" "MIME-Version: 1.0\n" "Content-Type: text/plain; charset=utf-8\n" "Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=2; plural=(n > 1);\n" -#: extract_me.py:1 +#: extract_me.py:11 msgid "school" msgstr "école" + +#: extract_me.py:12 +msgctxt "law" +msgid "right" +msgstr "le droit" + +#: extract_me.py:13 +msgctxt "good" +msgid "right" +msgstr "le bien" + +#: extract_me.py:14 +msgctxt "organization" +msgid "club" +msgid_plural "clubs" +msgstr[0] "le club" +msgstr[1] "les clubs" + +#: extract_me.py:15 +msgctxt "stick" +msgid "club" +msgid_plural "clubs" +msgstr[0] "le bâton" +msgstr[1] "les bâtons" diff --git a/tornado/test/httpclient_test.py b/tornado/test/httpclient_test.py index 875864ac69ff66ab851a8b95819ba6f62e1fc607..ecc63e4a49e50f1a685c1ad247978a089c2de8e0 100644 --- a/tornado/test/httpclient_test.py +++ b/tornado/test/httpclient_test.py @@ -12,6 +12,7 @@ import datetime from io import BytesIO from tornado.escape import utf8 +from tornado import gen from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient from tornado.httpserver import HTTPServer from tornado.ioloop import IOLoop @@ -52,9 +53,12 @@ class RedirectHandler(RequestHandler): class ChunkHandler(RequestHandler): + @gen.coroutine def get(self): self.write("asdf") self.flush() + # Wait a bit to ensure the chunks are sent and received separately. 
+ yield gen.sleep(0.01) self.write("qwer") @@ -178,6 +182,8 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase): sock, port = bind_unused_port() with closing(sock): def write_response(stream, request_data): + if b"HTTP/1." not in request_data: + self.skipTest("requires HTTP/1.x") stream.write(b"""\ HTTP/1.1 200 OK Transfer-Encoding: chunked @@ -300,23 +306,26 @@ Transfer-Encoding: chunked chunks = [] def header_callback(header_line): - if header_line.startswith('HTTP/'): + if header_line.startswith('HTTP/1.1 101'): + # Upgrading to HTTP/2 + pass + elif header_line.startswith('HTTP/'): first_line.append(header_line) elif header_line != '\r\n': k, v = header_line.split(':', 1) - headers[k] = v.strip() + headers[k.lower()] = v.strip() def streaming_callback(chunk): # All header callbacks are run before any streaming callbacks, # so the header data is available to process the data as it # comes in. - self.assertEqual(headers['Content-Type'], 'text/html; charset=UTF-8') + self.assertEqual(headers['content-type'], 'text/html; charset=UTF-8') chunks.append(chunk) self.fetch('/chunk', header_callback=header_callback, streaming_callback=streaming_callback) - self.assertEqual(len(first_line), 1) - self.assertRegexpMatches(first_line[0], 'HTTP/1.[01] 200 OK\r\n') + self.assertEqual(len(first_line), 1, first_line) + self.assertRegexpMatches(first_line[0], 'HTTP/[0-9]\\.[0-9] 200.*\r\n') self.assertEqual(chunks, [b'asdf', b'qwer']) def test_header_callback_stack_context(self): @@ -327,7 +336,7 @@ Transfer-Encoding: chunked return True def header_callback(header_line): - if header_line.startswith('Content-Type:'): + if header_line.lower().startswith('content-type:'): 1 / 0 with ExceptionStackContext(error_handler): @@ -459,7 +468,7 @@ Transfer-Encoding: chunked # Twisted's reactor does not. The removeReader call fails and so # do all future removeAll calls (which our tests do at cleanup). # - #def test_post_307(self): + # def test_post_307(self): # response = self.fetch("/redirect?status=307&url=/post", # method="POST", body=b"arg1=foo&arg2=bar") # self.assertEqual(response.body, b"Post arg1: foo, arg2: bar") @@ -541,7 +550,12 @@ class SyncHTTPClientTest(unittest.TestCase): def tearDown(self): def stop_server(): self.server.stop() - self.server_ioloop.stop() + # Delay the shutdown of the IOLoop by one iteration because + # the server may still have some cleanup work left when + # the client finishes with the response (this is noticable + # with http/2, which leaves a Future with an unexamined + # StreamClosedError on the loop). 
+ self.server_ioloop.add_callback(self.server_ioloop.stop) self.server_ioloop.add_callback(stop_server) self.server_thread.join() self.http_client.close() @@ -589,5 +603,5 @@ class HTTPRequestTestCase(unittest.TestCase): def test_if_modified_since(self): http_date = datetime.datetime.utcnow() request = HTTPRequest('http://example.com', if_modified_since=http_date) - self.assertEqual(request.headers, - {'If-Modified-Since': format_timestamp(http_date)}) + self.assertEqual(request.headers, + {'If-Modified-Since': format_timestamp(http_date)}) diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index 64ef96d459406d7224bbe1e968da62824ffc6291..f05599dd12fe508a281f65daa0dbeb19ec4f4fb9 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -32,6 +32,7 @@ def read_stream_body(stream, callback): """Reads an HTTP response from `stream` and runs callback with its headers and body.""" chunks = [] + class Delegate(HTTPMessageDelegate): def headers_received(self, start_line, headers): self.headers = headers @@ -161,19 +162,22 @@ class BadSSLOptionsTest(unittest.TestCase): application = Application() module_dir = os.path.dirname(__file__) existing_certificate = os.path.join(module_dir, 'test.crt') + existing_key = os.path.join(module_dir, 'test.key') - self.assertRaises(ValueError, HTTPServer, application, ssl_options={ - "certfile": "/__mising__.crt", + self.assertRaises((ValueError, IOError), + HTTPServer, application, ssl_options={ + "certfile": "/__mising__.crt", }) - self.assertRaises(ValueError, HTTPServer, application, ssl_options={ - "certfile": existing_certificate, - "keyfile": "/__missing__.key" + self.assertRaises((ValueError, IOError), + HTTPServer, application, ssl_options={ + "certfile": existing_certificate, + "keyfile": "/__missing__.key" }) # This actually works because both files exist HTTPServer(application, ssl_options={ "certfile": existing_certificate, - "keyfile": existing_certificate + "keyfile": existing_key, }) @@ -589,6 +593,7 @@ class KeepAliveTest(AsyncHTTPTestCase): class HelloHandler(RequestHandler): def get(self): self.finish('Hello world') + def post(self): self.finish('Hello world') @@ -863,6 +868,7 @@ class StreamingChunkSizeTest(AsyncHTTPTestCase): def test_chunked_compressed(self): compressed = self.compress(self.BODY) self.assertGreater(len(compressed), 20) + def body_producer(write): write(compressed[:20]) write(compressed[20:]) @@ -1052,6 +1058,15 @@ class LegacyInterfaceTest(AsyncHTTPTestCase): # delegate interface, and writes its response via request.write # instead of request.connection.write_headers. def handle_request(request): + self.http1 = request.version.startswith("HTTP/1.") + if not self.http1: + # This test will be skipped if we're using HTTP/2, + # so just close it out cleanly using the modern interface. 
+ request.connection.write_headers( + ResponseStartLine('', 200, 'OK'), + HTTPHeaders()) + request.connection.finish() + return message = b"Hello world" request.write(utf8("HTTP/1.1 200 OK\r\n" "Content-Length: %d\r\n\r\n" % len(message))) @@ -1061,4 +1076,6 @@ class LegacyInterfaceTest(AsyncHTTPTestCase): def test_legacy_interface(self): response = self.fetch('/') + if not self.http1: + self.skipTest("requires HTTP/1.x") self.assertEqual(response.body, b"Hello world") diff --git a/tornado/test/httputil_test.py b/tornado/test/httputil_test.py index 3995abe8b94ab9166e0ed43205b12cf4ef578e3d..6e95360174de84b1e2f5f45f27000fcfe80f2dbc 100644 --- a/tornado/test/httputil_test.py +++ b/tornado/test/httputil_test.py @@ -9,6 +9,7 @@ from tornado.testing import ExpectLog from tornado.test.util import unittest from tornado.util import u +import copy import datetime import logging import time @@ -237,14 +238,14 @@ Foo: even # and cpython's unicodeobject.c (which defines the implementation # of unicode_type.splitlines(), and uses a different list than TR13). newlines = [ - u('\u001b'), # VERTICAL TAB - u('\u001c'), # FILE SEPARATOR - u('\u001d'), # GROUP SEPARATOR - u('\u001e'), # RECORD SEPARATOR - u('\u0085'), # NEXT LINE - u('\u2028'), # LINE SEPARATOR - u('\u2029'), # PARAGRAPH SEPARATOR - ] + u('\u001b'), # VERTICAL TAB + u('\u001c'), # FILE SEPARATOR + u('\u001d'), # GROUP SEPARATOR + u('\u001e'), # RECORD SEPARATOR + u('\u0085'), # NEXT LINE + u('\u2028'), # LINE SEPARATOR + u('\u2029'), # PARAGRAPH SEPARATOR + ] for newline in newlines: # Try the utf8 and latin1 representations of each newline for encoding in ['utf8', 'latin1']: @@ -278,7 +279,25 @@ Foo: even [('Cr', 'cr\rMore: more'), ('Crlf', 'crlf'), ('Lf', 'lf'), - ]) + ]) + + def test_copy(self): + all_pairs = [('A', '1'), ('A', '2'), ('B', 'c')] + h1 = HTTPHeaders() + for k, v in all_pairs: + h1.add(k, v) + h2 = h1.copy() + h3 = copy.copy(h1) + h4 = copy.deepcopy(h1) + for headers in [h1, h2, h3, h4]: + # All the copies are identical, no matter how they were + # constructed. + self.assertEqual(list(sorted(headers.get_all())), all_pairs) + for headers in [h2, h3, h4]: + # Neither the dict or its member lists are reused. 
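A brief sketch of the copy semantics this test pins down (assuming the 4.2 behavior that ``copy.copy`` now produces an independent deep copy of the header value lists)::

    import copy
    from tornado.httputil import HTTPHeaders

    h = HTTPHeaders()
    h.add("Set-Cookie", "a=b")
    h.add("Set-Cookie", "c=d")

    for clone in (h.copy(), copy.copy(h), copy.deepcopy(h)):
        # Mutating a clone leaves the original's value lists intact.
        clone.add("Set-Cookie", "e=f")
        assert len(h.get_list("Set-Cookie")) == 2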
+ self.assertIsNot(headers, h1) + self.assertIsNot(headers.get_list('A'), h1.get_list('A')) + class FormatTimestampTest(unittest.TestCase): diff --git a/tornado/test/import_test.py b/tornado/test/import_test.py index de7cc0b9fbf51e7cf464944107cc77fc9f77f619..1be6427f19aa7d52d814c0fc045d4f057e011855 100644 --- a/tornado/test/import_test.py +++ b/tornado/test/import_test.py @@ -1,3 +1,4 @@ +# flake8: noqa from __future__ import absolute_import, division, print_function, with_statement from tornado.test.util import unittest diff --git a/tornado/test/ioloop_test.py b/tornado/test/ioloop_test.py index 7eb7594fd32448c27bca7ab0c29ca979177be827..f3a0cbdcfe7313bb1057886d8acf01238c821b3e 100644 --- a/tornado/test/ioloop_test.py +++ b/tornado/test/ioloop_test.py @@ -11,8 +11,9 @@ import threading import time from tornado import gen -from tornado.ioloop import IOLoop, TimeoutError +from tornado.ioloop import IOLoop, TimeoutError, PollIOLoop, PeriodicCallback from tornado.log import app_log +from tornado.platform.select import _Select from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog from tornado.test.util import unittest, skipIfNonUnix, skipOnTravis @@ -23,6 +24,42 @@ except ImportError: futures = None +class FakeTimeSelect(_Select): + def __init__(self): + self._time = 1000 + super(FakeTimeSelect, self).__init__() + + def time(self): + return self._time + + def sleep(self, t): + self._time += t + + def poll(self, timeout): + events = super(FakeTimeSelect, self).poll(0) + if events: + return events + self._time += timeout + return [] + + +class FakeTimeIOLoop(PollIOLoop): + """IOLoop implementation with a fake and deterministic clock. + + The clock advances as needed to trigger timeouts immediately. + For use when testing code that involves the passage of time + and no external dependencies. + """ + def initialize(self): + self.fts = FakeTimeSelect() + super(FakeTimeIOLoop, self).initialize(impl=self.fts, + time_func=self.fts.time) + + def sleep(self, t): + """Simulate a blocking sleep by advancing the clock.""" + self.fts.sleep(t) + + class TestIOLoop(AsyncTestCase): @skipOnTravis def test_add_callback_wakeup(self): @@ -180,10 +217,12 @@ class TestIOLoop(AsyncTestCase): # t2 should be cancelled by t1, even though it is already scheduled to # be run before the ioloop even looks at it. now = self.io_loop.time() + def t1(): calls[0] = True self.io_loop.remove_timeout(t2_handle) self.io_loop.add_timeout(now + 0.01, t1) + def t2(): calls[1] = True t2_handle = self.io_loop.add_timeout(now + 0.02, t2) @@ -252,6 +291,7 @@ class TestIOLoop(AsyncTestCase): """The handler callback receives the same fd object it passed in.""" server_sock, port = bind_unused_port() fds = [] + def handle_connection(fd, events): fds.append(fd) conn, addr = server_sock.accept() @@ -274,6 +314,7 @@ class TestIOLoop(AsyncTestCase): def test_mixed_fd_fileobj(self): server_sock, port = bind_unused_port() + def f(fd, events): pass self.io_loop.add_handler(server_sock, f, IOLoop.READ) @@ -288,6 +329,7 @@ class TestIOLoop(AsyncTestCase): """Calling start() twice should raise an error, not deadlock.""" returned_from_start = [False] got_exception = [False] + def callback(): try: self.io_loop.start() @@ -305,7 +347,7 @@ class TestIOLoop(AsyncTestCase): # Use a NullContext to keep the exception from being caught by # AsyncTestCase. 
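``FakeTimeIOLoop`` above is a test-local helper, not public API; a hedged sketch of how such a deterministic clock drives timeouts without real waiting::

    # Assumes FakeTimeIOLoop as defined in this test module.
    loop = FakeTimeIOLoop()
    loop.make_current()
    fired = []

    def cb():
        fired.append(loop.time())
        loop.stop()

    loop.call_later(30, cb)
    loop.start()            # returns promptly; poll() advanced the clock
    assert fired == [1030]  # FakeTimeSelect starts its clock at 1000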
with NullContext(): - self.io_loop.add_callback(lambda: 1/0) + self.io_loop.add_callback(lambda: 1 / 0) self.io_loop.add_callback(self.stop) with ExpectLog(app_log, "Exception in callback"): self.wait() @@ -316,7 +358,7 @@ class TestIOLoop(AsyncTestCase): @gen.coroutine def callback(): self.io_loop.add_callback(self.stop) - 1/0 + 1 / 0 self.io_loop.add_callback(callback) with ExpectLog(app_log, "Exception in callback"): self.wait() @@ -324,12 +366,12 @@ class TestIOLoop(AsyncTestCase): def test_spawn_callback(self): # An added callback runs in the test's stack_context, so will be # re-arised in wait(). - self.io_loop.add_callback(lambda: 1/0) + self.io_loop.add_callback(lambda: 1 / 0) with self.assertRaises(ZeroDivisionError): self.wait() # A spawned callback is run directly on the IOLoop, so it will be # logged without stopping the test. - self.io_loop.spawn_callback(lambda: 1/0) + self.io_loop.spawn_callback(lambda: 1 / 0) self.io_loop.add_callback(self.stop) with ExpectLog(app_log, "Exception in callback"): self.wait() @@ -344,6 +386,7 @@ class TestIOLoop(AsyncTestCase): # After reading from one fd, remove the other from the IOLoop. chunks = [] + def handle_read(fd, events): chunks.append(fd.recv(1024)) if fd is client: @@ -352,7 +395,7 @@ class TestIOLoop(AsyncTestCase): self.io_loop.remove_handler(client) self.io_loop.add_handler(client, handle_read, self.io_loop.READ) self.io_loop.add_handler(server, handle_read, self.io_loop.READ) - self.io_loop.call_later(0.01, self.stop) + self.io_loop.call_later(0.03, self.stop) self.wait() # Only one fd was read; the other was cleanly removed. @@ -520,5 +563,47 @@ class TestIOLoopRunSync(unittest.TestCase): self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01) +class TestPeriodicCallback(unittest.TestCase): + def setUp(self): + self.io_loop = FakeTimeIOLoop() + self.io_loop.make_current() + + def tearDown(self): + self.io_loop.close() + + def test_basic(self): + calls = [] + + def cb(): + calls.append(self.io_loop.time()) + pc = PeriodicCallback(cb, 10000) + pc.start() + self.io_loop.call_later(50, self.io_loop.stop) + self.io_loop.start() + self.assertEqual(calls, [1010, 1020, 1030, 1040, 1050]) + + def test_overrun(self): + sleep_durations = [9, 9, 10, 11, 20, 20, 35, 35, 0, 0] + expected = [ + 1010, 1020, 1030, # first 3 calls on schedule + 1050, 1070, # next 2 delayed one cycle + 1100, 1130, # next 2 delayed 2 cycles + 1170, 1210, # next 2 delayed 3 cycles + 1220, 1230, # then back on schedule. 
+ ] + calls = [] + + def cb(): + calls.append(self.io_loop.time()) + if not sleep_durations: + self.io_loop.stop() + return + self.io_loop.sleep(sleep_durations.pop(0)) + pc = PeriodicCallback(cb, 10000) + pc.start() + self.io_loop.start() + self.assertEqual(calls, expected) + + if __name__ == "__main__": unittest.main() diff --git a/tornado/test/iostream_test.py b/tornado/test/iostream_test.py index ca35de69bbae0b48467642c4b157e17f3878cab1..45df6b50a72cd6a8312e02bf52b6c46f5ed4c9c8 100644 --- a/tornado/test/iostream_test.py +++ b/tornado/test/iostream_test.py @@ -7,10 +7,10 @@ from tornado.httputil import HTTPHeaders from tornado.log import gen_log, app_log from tornado.netutil import ssl_wrap_socket from tornado.stack_context import NullContext +from tornado.tcpserver import TCPServer from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test from tornado.test.util import unittest, skipIfNonUnix, refusing_port from tornado.web import RequestHandler, Application -import certifi import errno import logging import os @@ -310,6 +310,7 @@ class TestIOStreamMixin(object): def streaming_callback(data): chunks.append(data) self.stop() + def close_callback(data): assert not data, data closed[0] = True @@ -327,6 +328,31 @@ class TestIOStreamMixin(object): server.close() client.close() + def test_streaming_until_close_future(self): + server, client = self.make_iostream_pair() + try: + chunks = [] + + @gen.coroutine + def client_task(): + yield client.read_until_close(streaming_callback=chunks.append) + + @gen.coroutine + def server_task(): + yield server.write(b"1234") + yield gen.sleep(0.01) + yield server.write(b"5678") + server.close() + + @gen.coroutine + def f(): + yield [client_task(), server_task()] + self.io_loop.run_sync(f) + self.assertEqual(chunks, [b"1234", b"5678"]) + finally: + server.close() + client.close() + def test_delayed_close_callback(self): # The scenario: Server closes the connection while there is a pending # read that can be served out of buffered data. The client does not @@ -355,6 +381,7 @@ class TestIOStreamMixin(object): def test_future_delayed_close_callback(self): # Same as test_delayed_close_callback, but with the future interface. server, client = self.make_iostream_pair() + # We can't call make_iostream_pair inside a gen_test function # because the ioloop is not reentrant. @gen_test @@ -534,6 +561,7 @@ class TestIOStreamMixin(object): # and IOStream._maybe_add_error_listener. 
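For reference, the production use of ``PeriodicCallback`` that these fake-clock tests model; a minimal sketch (the 10-second period is illustrative)::

    from tornado.ioloop import IOLoop, PeriodicCallback

    def poll_status():
        print(IOLoop.current().time())

    # callback_time is in milliseconds. If a call overruns its slot,
    # missed invocations are skipped and the schedule realigns, as
    # test_overrun above expects.
    pc = PeriodicCallback(poll_status, 10000)
    pc.start()
    IOLoop.current().start()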
server, client = self.make_iostream_pair() closed = [False] + def close_callback(): closed[0] = True self.stop() @@ -754,7 +782,8 @@ class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase): class TestIOStreamWebHTTPS(TestIOStreamWebMixin, AsyncHTTPSTestCase): def _make_client_iostream(self): - return SSLIOStream(socket.socket(), io_loop=self.io_loop) + return SSLIOStream(socket.socket(), io_loop=self.io_loop, + ssl_options=dict(cert_reqs=ssl.CERT_NONE)) class TestIOStream(TestIOStreamMixin, AsyncTestCase): @@ -774,7 +803,9 @@ class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase): return SSLIOStream(connection, io_loop=self.io_loop, **kwargs) def _make_client_iostream(self, connection, **kwargs): - return SSLIOStream(connection, io_loop=self.io_loop, **kwargs) + return SSLIOStream(connection, io_loop=self.io_loop, + ssl_options=dict(cert_reqs=ssl.CERT_NONE), + **kwargs) # This will run some tests that are basically redundant but it's the @@ -864,7 +895,7 @@ class TestIOStreamStartTLS(AsyncTestCase): yield self.server_send_line(b"250 STARTTLS\r\n") yield self.client_send_line(b"STARTTLS\r\n") yield self.server_send_line(b"220 Go ahead\r\n") - client_future = self.client_start_tls() + client_future = self.client_start_tls(dict(cert_reqs=ssl.CERT_NONE)) server_future = self.server_start_tls(_server_ssl_options()) self.client_stream = yield client_future self.server_stream = yield server_future @@ -876,15 +907,14 @@ class TestIOStreamStartTLS(AsyncTestCase): @gen_test def test_handshake_fail(self): server_future = self.server_start_tls(_server_ssl_options()) - client_future = self.client_start_tls( - dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=certifi.where())) + # Certificates are verified with the default configuration. + client_future = self.client_start_tls(server_hostname="localhost") with ExpectLog(gen_log, "SSL Error"): with self.assertRaises(ssl.SSLError): yield client_future with self.assertRaises((ssl.SSLError, socket.error)): yield server_future - @unittest.skipIf(not hasattr(ssl, 'create_default_context'), 'ssl.create_default_context not present') @gen_test @@ -903,6 +933,98 @@ class TestIOStreamStartTLS(AsyncTestCase): yield server_future +class WaitForHandshakeTest(AsyncTestCase): + @gen.coroutine + def connect_to_server(self, server_cls): + server = client = None + try: + sock, port = bind_unused_port() + server = server_cls(ssl_options=_server_ssl_options()) + server.add_socket(sock) + + client = SSLIOStream(socket.socket(), + ssl_options=dict(cert_reqs=ssl.CERT_NONE)) + yield client.connect(('127.0.0.1', port)) + self.assertIsNotNone(client.socket.cipher()) + finally: + if server is not None: + server.stop() + if client is not None: + client.close() + + @gen_test + def test_wait_for_handshake_callback(self): + test = self + handshake_future = Future() + + class TestServer(TCPServer): + def handle_stream(self, stream, address): + # The handshake has not yet completed. + test.assertIsNone(stream.socket.cipher()) + self.stream = stream + stream.wait_for_handshake(self.handshake_done) + + def handshake_done(self): + # Now the handshake is done and ssl information is available. 
+ test.assertIsNotNone(self.stream.socket.cipher()) + handshake_future.set_result(None) + + yield self.connect_to_server(TestServer) + yield handshake_future + + @gen_test + def test_wait_for_handshake_future(self): + test = self + handshake_future = Future() + + class TestServer(TCPServer): + def handle_stream(self, stream, address): + test.assertIsNone(stream.socket.cipher()) + test.io_loop.spawn_callback(self.handle_connection, stream) + + @gen.coroutine + def handle_connection(self, stream): + yield stream.wait_for_handshake() + handshake_future.set_result(None) + + yield self.connect_to_server(TestServer) + yield handshake_future + + @gen_test + def test_wait_for_handshake_already_waiting_error(self): + test = self + handshake_future = Future() + + class TestServer(TCPServer): + def handle_stream(self, stream, address): + stream.wait_for_handshake(self.handshake_done) + test.assertRaises(RuntimeError, stream.wait_for_handshake) + + def handshake_done(self): + handshake_future.set_result(None) + + yield self.connect_to_server(TestServer) + yield handshake_future + + @gen_test + def test_wait_for_handshake_already_connected(self): + handshake_future = Future() + + class TestServer(TCPServer): + def handle_stream(self, stream, address): + self.stream = stream + stream.wait_for_handshake(self.handshake_done) + + def handshake_done(self): + self.stream.wait_for_handshake(self.handshake2_done) + + def handshake2_done(self): + handshake_future.set_result(None) + + yield self.connect_to_server(TestServer) + yield handshake_future + + @skipIfNonUnix class TestPipeIOStream(AsyncTestCase): def test_pipe_iostream(self): diff --git a/tornado/test/locale_test.py b/tornado/test/locale_test.py index d12ad52ffa810e7d809b53c71bb94a399eeba4e3..31c57a6194c9632931b6435efc9267d149918b3c 100644 --- a/tornado/test/locale_test.py +++ b/tornado/test/locale_test.py @@ -41,6 +41,12 @@ class TranslationLoaderTest(unittest.TestCase): locale = tornado.locale.get("fr_FR") self.assertTrue(isinstance(locale, tornado.locale.GettextLocale)) self.assertEqual(locale.translate("school"), u("\u00e9cole")) + self.assertEqual(locale.pgettext("law", "right"), u("le droit")) + self.assertEqual(locale.pgettext("good", "right"), u("le bien")) + self.assertEqual(locale.pgettext("organization", "club", "clubs", 1), u("le club")) + self.assertEqual(locale.pgettext("organization", "club", "clubs", 2), u("les clubs")) + self.assertEqual(locale.pgettext("stick", "club", "clubs", 1), u("le b\xe2ton")) + self.assertEqual(locale.pgettext("stick", "club", "clubs", 2), u("les b\xe2tons")) class LocaleDataTest(unittest.TestCase): @@ -58,7 +64,7 @@ class EnglishTest(unittest.TestCase): self.assertEqual(locale.format_date(date, full_format=True), 'April 28, 2013 at 6:35 pm') - self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(seconds=2), full_format=False), + self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(seconds=2), full_format=False), '2 seconds ago') self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(minutes=2), full_format=False), '2 minutes ago') diff --git a/tornado/test/locks_test.py b/tornado/test/locks_test.py new file mode 100644 index 0000000000000000000000000000000000000000..90bdafaa6020a64057b8299eace09785a13dfb05 --- /dev/null +++ b/tornado/test/locks_test.py @@ -0,0 +1,480 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from datetime import timedelta + +from tornado import gen, locks +from tornado.gen import TimeoutError +from tornado.testing import gen_test, AsyncTestCase +from tornado.test.util import unittest + + +class ConditionTest(AsyncTestCase): + def setUp(self): + super(ConditionTest, self).setUp() + self.history = [] + + def record_done(self, future, key): + """Record the resolution of a Future returned by Condition.wait.""" + def callback(_): + if not future.result(): + # wait() resolved to False, meaning it timed out. + self.history.append('timeout') + else: + self.history.append(key) + future.add_done_callback(callback) + + def test_repr(self): + c = locks.Condition() + self.assertIn('Condition', repr(c)) + self.assertNotIn('waiters', repr(c)) + c.wait() + self.assertIn('waiters', repr(c)) + + @gen_test + def test_notify(self): + c = locks.Condition() + self.io_loop.call_later(0.01, c.notify) + yield c.wait() + + def test_notify_1(self): + c = locks.Condition() + self.record_done(c.wait(), 'wait1') + self.record_done(c.wait(), 'wait2') + c.notify(1) + self.history.append('notify1') + c.notify(1) + self.history.append('notify2') + self.assertEqual(['wait1', 'notify1', 'wait2', 'notify2'], + self.history) + + def test_notify_n(self): + c = locks.Condition() + for i in range(6): + self.record_done(c.wait(), i) + + c.notify(3) + + # Callbacks execute in the order they were registered. + self.assertEqual(list(range(3)), self.history) + c.notify(1) + self.assertEqual(list(range(4)), self.history) + c.notify(2) + self.assertEqual(list(range(6)), self.history) + + def test_notify_all(self): + c = locks.Condition() + for i in range(4): + self.record_done(c.wait(), i) + + c.notify_all() + self.history.append('notify_all') + + # Callbacks execute in the order they were registered. + self.assertEqual( + list(range(4)) + ['notify_all'], + self.history) + + @gen_test + def test_wait_timeout(self): + c = locks.Condition() + wait = c.wait(timedelta(seconds=0.01)) + self.io_loop.call_later(0.02, c.notify) # Too late. + yield gen.sleep(0.03) + self.assertFalse((yield wait)) + + @gen_test + def test_wait_timeout_preempted(self): + c = locks.Condition() + + # This fires before the wait times out. + self.io_loop.call_later(0.01, c.notify) + wait = c.wait(timedelta(seconds=0.02)) + yield gen.sleep(0.03) + yield wait # No TimeoutError. + + @gen_test + def test_notify_n_with_timeout(self): + # Register callbacks 0, 1, 2, and 3. Callback 1 has a timeout. + # Wait for that timeout to expire, then do notify(2) and make + # sure everyone runs. Verifies that a timed-out callback does + # not count against the 'n' argument to notify(). + c = locks.Condition() + self.record_done(c.wait(), 0) + self.record_done(c.wait(timedelta(seconds=0.01)), 1) + self.record_done(c.wait(), 2) + self.record_done(c.wait(), 3) + + # Wait for callback 1 to time out. 
+ yield gen.sleep(0.02) + self.assertEqual(['timeout'], self.history) + + c.notify(2) + yield gen.sleep(0.01) + self.assertEqual(['timeout', 0, 2], self.history) + self.assertEqual(['timeout', 0, 2], self.history) + c.notify() + self.assertEqual(['timeout', 0, 2, 3], self.history) + + @gen_test + def test_notify_all_with_timeout(self): + c = locks.Condition() + self.record_done(c.wait(), 0) + self.record_done(c.wait(timedelta(seconds=0.01)), 1) + self.record_done(c.wait(), 2) + + # Wait for callback 1 to time out. + yield gen.sleep(0.02) + self.assertEqual(['timeout'], self.history) + + c.notify_all() + self.assertEqual(['timeout', 0, 2], self.history) + + @gen_test + def test_nested_notify(self): + # Ensure no notifications lost, even if notify() is reentered by a + # waiter calling notify(). + c = locks.Condition() + + # Three waiters. + futures = [c.wait() for _ in range(3)] + + # First and second futures resolved. Second future reenters notify(), + # resolving third future. + futures[1].add_done_callback(lambda _: c.notify()) + c.notify(2) + self.assertTrue(all(f.done() for f in futures)) + + @gen_test + def test_garbage_collection(self): + # Test that timed-out waiters are occasionally cleaned from the queue. + c = locks.Condition() + for _ in range(101): + c.wait(timedelta(seconds=0.01)) + + future = c.wait() + self.assertEqual(102, len(c._waiters)) + + # Let first 101 waiters time out, triggering a collection. + yield gen.sleep(0.02) + self.assertEqual(1, len(c._waiters)) + + # Final waiter is still active. + self.assertFalse(future.done()) + c.notify() + self.assertTrue(future.done()) + + +class EventTest(AsyncTestCase): + def test_repr(self): + event = locks.Event() + self.assertTrue('clear' in str(event)) + self.assertFalse('set' in str(event)) + event.set() + self.assertFalse('clear' in str(event)) + self.assertTrue('set' in str(event)) + + def test_event(self): + e = locks.Event() + future_0 = e.wait() + e.set() + future_1 = e.wait() + e.clear() + future_2 = e.wait() + + self.assertTrue(future_0.done()) + self.assertTrue(future_1.done()) + self.assertFalse(future_2.done()) + + @gen_test + def test_event_timeout(self): + e = locks.Event() + with self.assertRaises(TimeoutError): + yield e.wait(timedelta(seconds=0.01)) + + # After a timed-out waiter, normal operation works. + self.io_loop.add_timeout(timedelta(seconds=0.01), e.set) + yield e.wait(timedelta(seconds=1)) + + def test_event_set_multiple(self): + e = locks.Event() + e.set() + e.set() + self.assertTrue(e.is_set()) + + def test_event_wait_clear(self): + e = locks.Event() + f0 = e.wait() + e.clear() + f1 = e.wait() + e.set() + self.assertTrue(f0.done()) + self.assertTrue(f1.done()) + + +class SemaphoreTest(AsyncTestCase): + def test_negative_value(self): + self.assertRaises(ValueError, locks.Semaphore, value=-1) + + def test_repr(self): + sem = locks.Semaphore() + self.assertIn('Semaphore', repr(sem)) + self.assertIn('unlocked,value:1', repr(sem)) + sem.acquire() + self.assertIn('locked', repr(sem)) + self.assertNotIn('waiters', repr(sem)) + sem.acquire() + self.assertIn('waiters', repr(sem)) + + def test_acquire(self): + sem = locks.Semaphore() + f0 = sem.acquire() + self.assertTrue(f0.done()) + + # Wait for release(). + f1 = sem.acquire() + self.assertFalse(f1.done()) + f2 = sem.acquire() + sem.release() + self.assertTrue(f1.done()) + self.assertFalse(f2.done()) + sem.release() + self.assertTrue(f2.done()) + + sem.release() + # Now acquire() is instant. 
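Outside the tests, the ``Condition`` primitive covered above follows the familiar wait/notify shape; a hedged sketch (coroutine names are illustrative)::

    from datetime import timedelta
    from tornado import gen, locks

    cond = locks.Condition()

    @gen.coroutine
    def consumer():
        # wait() resolves to False if the timeout expires first.
        notified = yield cond.wait(timeout=timedelta(seconds=1))
        raise gen.Return(notified)

    @gen.coroutine
    def producer():
        cond.notify()  # wakes one waiter; notify_all() wakes them all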
+ self.assertTrue(sem.acquire().done()) + self.assertEqual(0, len(sem._waiters)) + + @gen_test + def test_acquire_timeout(self): + sem = locks.Semaphore(2) + yield sem.acquire() + yield sem.acquire() + acquire = sem.acquire(timedelta(seconds=0.01)) + self.io_loop.call_later(0.02, sem.release) # Too late. + yield gen.sleep(0.3) + with self.assertRaises(gen.TimeoutError): + yield acquire + + sem.acquire() + f = sem.acquire() + self.assertFalse(f.done()) + sem.release() + self.assertTrue(f.done()) + + @gen_test + def test_acquire_timeout_preempted(self): + sem = locks.Semaphore(1) + yield sem.acquire() + + # This fires before the wait times out. + self.io_loop.call_later(0.01, sem.release) + acquire = sem.acquire(timedelta(seconds=0.02)) + yield gen.sleep(0.03) + yield acquire # No TimeoutError. + + def test_release_unacquired(self): + # Unbounded releases are allowed, and increment the semaphore's value. + sem = locks.Semaphore() + sem.release() + sem.release() + + # Now the counter is 3. We can acquire three times before blocking. + self.assertTrue(sem.acquire().done()) + self.assertTrue(sem.acquire().done()) + self.assertTrue(sem.acquire().done()) + self.assertFalse(sem.acquire().done()) + + @gen_test + def test_garbage_collection(self): + # Test that timed-out waiters are occasionally cleaned from the queue. + sem = locks.Semaphore(value=0) + futures = [sem.acquire(timedelta(seconds=0.01)) for _ in range(101)] + + future = sem.acquire() + self.assertEqual(102, len(sem._waiters)) + + # Let first 101 waiters time out, triggering a collection. + yield gen.sleep(0.02) + self.assertEqual(1, len(sem._waiters)) + + # Final waiter is still active. + self.assertFalse(future.done()) + sem.release() + self.assertTrue(future.done()) + + # Prevent "Future exception was never retrieved" messages. + for future in futures: + self.assertRaises(TimeoutError, future.result) + + +class SemaphoreContextManagerTest(AsyncTestCase): + @gen_test + def test_context_manager(self): + sem = locks.Semaphore() + with (yield sem.acquire()) as yielded: + self.assertTrue(yielded is None) + + # Semaphore was released and can be acquired again. + self.assertTrue(sem.acquire().done()) + + @gen_test + def test_context_manager_exception(self): + sem = locks.Semaphore() + with self.assertRaises(ZeroDivisionError): + with (yield sem.acquire()): + 1 / 0 + + # Semaphore was released and can be acquired again. + self.assertTrue(sem.acquire().done()) + + @gen_test + def test_context_manager_timeout(self): + sem = locks.Semaphore() + with (yield sem.acquire(timedelta(seconds=0.01))): + pass + + # Semaphore was released and can be acquired again. + self.assertTrue(sem.acquire().done()) + + @gen_test + def test_context_manager_timeout_error(self): + sem = locks.Semaphore(value=0) + with self.assertRaises(gen.TimeoutError): + with (yield sem.acquire(timedelta(seconds=0.01))): + pass + + # Counter is still 0. + self.assertFalse(sem.acquire().done()) + + @gen_test + def test_context_manager_contended(self): + sem = locks.Semaphore() + history = [] + + @gen.coroutine + def f(index): + with (yield sem.acquire()): + history.append('acquired %d' % index) + yield gen.sleep(0.01) + history.append('release %d' % index) + + yield [f(i) for i in range(2)] + + expected_history = [] + for i in range(2): + expected_history.extend(['acquired %d' % i, 'release %d' % i]) + + self.assertEqual(expected_history, history) + + @gen_test + def test_yield_sem(self): + # Ensure we catch a "with (yield sem)", which should be + # "with (yield sem.acquire())". 
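The context-manager form tested above maps naturally onto limiting concurrency; a minimal sketch (``process`` is a hypothetical application coroutine)::

    from tornado import gen, locks

    sem = locks.Semaphore(4)  # at most four jobs in flight

    @gen.coroutine
    def worker(job):
        with (yield sem.acquire()):
            # Released on exit from the with block, even on error.
            yield process(job)  # hypothetical coroutine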
+ with self.assertRaises(gen.BadYieldError): + with (yield locks.Semaphore()): + pass + + def test_context_manager_misuse(self): + # Ensure we catch a "with sem", which should be + # "with (yield sem.acquire())". + with self.assertRaises(RuntimeError): + with locks.Semaphore(): + pass + + +class BoundedSemaphoreTest(AsyncTestCase): + def test_release_unacquired(self): + sem = locks.BoundedSemaphore() + self.assertRaises(ValueError, sem.release) + # Value is 0. + sem.acquire() + # Block on acquire(). + future = sem.acquire() + self.assertFalse(future.done()) + sem.release() + self.assertTrue(future.done()) + # Value is 1. + sem.release() + self.assertRaises(ValueError, sem.release) + + +class LockTests(AsyncTestCase): + def test_repr(self): + lock = locks.Lock() + # No errors. + repr(lock) + lock.acquire() + repr(lock) + + def test_acquire_release(self): + lock = locks.Lock() + self.assertTrue(lock.acquire().done()) + future = lock.acquire() + self.assertFalse(future.done()) + lock.release() + self.assertTrue(future.done()) + + @gen_test + def test_acquire_fifo(self): + lock = locks.Lock() + self.assertTrue(lock.acquire().done()) + N = 5 + history = [] + + @gen.coroutine + def f(idx): + with (yield lock.acquire()): + history.append(idx) + + futures = [f(i) for i in range(N)] + self.assertFalse(any(future.done() for future in futures)) + lock.release() + yield futures + self.assertEqual(list(range(N)), history) + + @gen_test + def test_acquire_timeout(self): + lock = locks.Lock() + lock.acquire() + with self.assertRaises(gen.TimeoutError): + yield lock.acquire(timeout=timedelta(seconds=0.01)) + + # Still locked. + self.assertFalse(lock.acquire().done()) + + def test_multi_release(self): + lock = locks.Lock() + self.assertRaises(RuntimeError, lock.release) + lock.acquire() + lock.release() + self.assertRaises(RuntimeError, lock.release) + + @gen_test + def test_yield_lock(self): + # Ensure we catch a "with (yield lock)", which should be + # "with (yield lock.acquire())". + with self.assertRaises(gen.BadYieldError): + with (yield locks.Lock()): + pass + + def test_context_manager_misuse(self): + # Ensure we catch a "with lock", which should be + # "with (yield lock.acquire())". 
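Lock follows the same acquire/release protocol, and test_acquire_fifo above pins down that waiters are woken in arrival order. A minimal sketch of the idiom, with illustrative names:

    from tornado import gen, ioloop, locks

    lock = locks.Lock()

    @gen.coroutine
    def critical(history, name):
        with (yield lock.acquire()):   # waiters are woken FIFO
            history.append(name)
            yield gen.moment

    @gen.coroutine
    def main():
        history = []
        yield [critical(history, n) for n in ('a', 'b', 'c')]
        raise gen.Return(history)

    print(ioloop.IOLoop.current().run_sync(main))   # ['a', 'b', 'c']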
+ with self.assertRaises(RuntimeError): + with locks.Lock(): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/tornado/test/netutil_test.py b/tornado/test/netutil_test.py index 1df1e32042f3711e70d1cbf6867baee153401a37..7d9cad34a0491b579813dff135b7d8db560a86a0 100644 --- a/tornado/test/netutil_test.py +++ b/tornado/test/netutil_test.py @@ -67,10 +67,12 @@ class _ResolverErrorTestMixin(object): yield self.resolver.resolve('an invalid domain', 80, socket.AF_UNSPEC) + def _failing_getaddrinfo(*args): """Dummy implementation of getaddrinfo for use in mocks""" raise socket.gaierror("mock: lookup failed") + @skipIfNoNetwork class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin): def setUp(self): diff --git a/tornado/test/process_test.py b/tornado/test/process_test.py index de727607ac85d15178196e9bd0c53fc092c9cab6..58cc410b68ab90542e9e0fa813649afb5c706d02 100644 --- a/tornado/test/process_test.py +++ b/tornado/test/process_test.py @@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop from tornado.log import gen_log from tornado.process import fork_processes, task_id, Subprocess from tornado.simple_httpclient import SimpleAsyncHTTPClient -from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase +from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase, gen_test from tornado.test.util import unittest, skipIfNonUnix from tornado.web import RequestHandler, Application @@ -85,7 +85,7 @@ class ProcessTest(unittest.TestCase): self.assertEqual(id, task_id()) server = HTTPServer(self.get_app()) server.add_sockets([sock]) - IOLoop.instance().start() + IOLoop.current().start() elif id == 2: self.assertEqual(id, task_id()) sock.close() @@ -200,6 +200,16 @@ class SubprocessTest(AsyncTestCase): self.assertEqual(ret, 0) self.assertEqual(subproc.returncode, ret) + @gen_test + def test_sigchild_future(self): + skip_if_twisted() + Subprocess.initialize() + self.addCleanup(Subprocess.uninitialize) + subproc = Subprocess([sys.executable, '-c', 'pass']) + ret = yield subproc.wait_for_exit() + self.assertEqual(ret, 0) + self.assertEqual(subproc.returncode, ret) + def test_sigchild_signal(self): skip_if_twisted() Subprocess.initialize(io_loop=self.io_loop) @@ -212,3 +222,22 @@ class SubprocessTest(AsyncTestCase): ret = self.wait() self.assertEqual(subproc.returncode, ret) self.assertEqual(ret, -signal.SIGTERM) + + @gen_test + def test_wait_for_exit_raise(self): + skip_if_twisted() + Subprocess.initialize() + self.addCleanup(Subprocess.uninitialize) + subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)']) + with self.assertRaises(subprocess.CalledProcessError) as cm: + yield subproc.wait_for_exit() + self.assertEqual(cm.exception.returncode, 1) + + @gen_test + def test_wait_for_exit_raise_disabled(self): + skip_if_twisted() + Subprocess.initialize() + self.addCleanup(Subprocess.uninitialize) + subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)']) + ret = yield subproc.wait_for_exit(raise_error=False) + self.assertEqual(ret, 1) diff --git a/tornado/test/queues_test.py b/tornado/test/queues_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f2ffb646f0c94a192f1dac28fd6084a0aaca0b6f --- /dev/null +++ b/tornado/test/queues_test.py @@ -0,0 +1,403 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from datetime import timedelta +from random import random + +from tornado import gen, queues +from tornado.gen import TimeoutError +from tornado.testing import gen_test, AsyncTestCase +from tornado.test.util import unittest + + +class QueueBasicTest(AsyncTestCase): + def test_repr_and_str(self): + q = queues.Queue(maxsize=1) + self.assertIn(hex(id(q)), repr(q)) + self.assertNotIn(hex(id(q)), str(q)) + q.get() + + for q_str in repr(q), str(q): + self.assertTrue(q_str.startswith('<Queue')) + self.assertIn('maxsize=1', q_str) + self.assertIn('getters[1]', q_str) + self.assertNotIn('putters', q_str) + self.assertNotIn('tasks', q_str) + + q.put(None) + q.put(None) + # Now the queue is full, this putter blocks. + q.put(None) + + for q_str in repr(q), str(q): + self.assertNotIn('getters', q_str) + self.assertIn('putters[1]', q_str) + self.assertIn('tasks=2', q_str) + + def test_order(self): + q = queues.Queue() + for i in [1, 3, 2]: + q.put_nowait(i) + + items = [q.get_nowait() for _ in range(3)] + self.assertEqual([1, 3, 2], items) + + @gen_test + def test_maxsize(self): + self.assertRaises(TypeError, queues.Queue, maxsize=None) + self.assertRaises(ValueError, queues.Queue, maxsize=-1) + + q = queues.Queue(maxsize=2) + self.assertTrue(q.empty()) + self.assertFalse(q.full()) + self.assertEqual(2, q.maxsize) + self.assertTrue(q.put(0).done()) + self.assertTrue(q.put(1).done()) + self.assertFalse(q.empty()) + self.assertTrue(q.full()) + put2 = q.put(2) + self.assertFalse(put2.done()) + self.assertEqual(0, (yield q.get())) # Make room. + self.assertTrue(put2.done()) + self.assertFalse(q.empty()) + self.assertTrue(q.full()) + + +class QueueGetTest(AsyncTestCase): + @gen_test + def test_blocking_get(self): + q = queues.Queue() + q.put_nowait(0) + self.assertEqual(0, (yield q.get())) + + def test_nonblocking_get(self): + q = queues.Queue() + q.put_nowait(0) + self.assertEqual(0, q.get_nowait()) + + def test_nonblocking_get_exception(self): + q = queues.Queue() + self.assertRaises(queues.QueueEmpty, q.get_nowait) + + @gen_test + def test_get_with_putters(self): + q = queues.Queue(1) + q.put_nowait(0) + put = q.put(1) + self.assertEqual(0, (yield q.get())) + self.assertIsNone((yield put)) + + @gen_test + def test_blocking_get_wait(self): + q = queues.Queue() + q.put(0) + self.io_loop.call_later(0.01, q.put, 1) + self.io_loop.call_later(0.02, q.put, 2) + self.assertEqual(0, (yield q.get(timeout=timedelta(seconds=1)))) + self.assertEqual(1, (yield q.get(timeout=timedelta(seconds=1)))) + + @gen_test + def test_get_timeout(self): + q = queues.Queue() + get_timeout = q.get(timeout=timedelta(seconds=0.01)) + get = q.get() + with self.assertRaises(TimeoutError): + yield get_timeout + + q.put_nowait(0) + self.assertEqual(0, (yield get)) + + @gen_test + def test_get_timeout_preempted(self): + q = queues.Queue() + get = q.get(timeout=timedelta(seconds=0.01)) + q.put(0) + yield gen.sleep(0.02) + self.assertEqual(0, (yield get)) + + @gen_test + def test_get_clears_timed_out_putters(self): + q = queues.Queue(1) + # First putter succeeds, remainder block. 
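For reference, the blocking get() with a timeout that these tests pin down looks like this in application code; a sketch, with illustrative queue contents:

    from datetime import timedelta
    from tornado import gen, ioloop, queues

    @gen.coroutine
    def main():
        q = queues.Queue(maxsize=2)
        yield q.put('ready')
        item = yield q.get(timeout=timedelta(seconds=1))
        assert item == 'ready'
        try:
            # Nothing queued now, so this get() times out.
            yield q.get(timeout=timedelta(seconds=0.01))
        except gen.TimeoutError:
            print('timed out as expected')

    ioloop.IOLoop.current().run_sync(main)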
+ putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)] + put = q.put(10) + self.assertEqual(10, len(q._putters)) + yield gen.sleep(0.02) + self.assertEqual(10, len(q._putters)) + self.assertFalse(put.done()) # Final waiter is still active. + q.put(11) + self.assertEqual(0, (yield q.get())) # get() clears the waiters. + self.assertEqual(1, len(q._putters)) + for putter in putters[1:]: + self.assertRaises(TimeoutError, putter.result) + + @gen_test + def test_get_clears_timed_out_getters(self): + q = queues.Queue() + getters = [q.get(timedelta(seconds=0.01)) for _ in range(10)] + get = q.get() + self.assertEqual(11, len(q._getters)) + yield gen.sleep(0.02) + self.assertEqual(11, len(q._getters)) + self.assertFalse(get.done()) # Final waiter is still active. + q.get() # get() clears the waiters. + self.assertEqual(2, len(q._getters)) + for getter in getters: + self.assertRaises(TimeoutError, getter.result) + + +class QueuePutTest(AsyncTestCase): + @gen_test + def test_blocking_put(self): + q = queues.Queue() + q.put(0) + self.assertEqual(0, q.get_nowait()) + + def test_nonblocking_put_exception(self): + q = queues.Queue(1) + q.put(0) + self.assertRaises(queues.QueueFull, q.put_nowait, 1) + + @gen_test + def test_put_with_getters(self): + q = queues.Queue() + get0 = q.get() + get1 = q.get() + yield q.put(0) + self.assertEqual(0, (yield get0)) + yield q.put(1) + self.assertEqual(1, (yield get1)) + + @gen_test + def test_nonblocking_put_with_getters(self): + q = queues.Queue() + get0 = q.get() + get1 = q.get() + q.put_nowait(0) + # put_nowait does *not* immediately unblock getters. + yield gen.moment + self.assertEqual(0, (yield get0)) + q.put_nowait(1) + yield gen.moment + self.assertEqual(1, (yield get1)) + + @gen_test + def test_blocking_put_wait(self): + q = queues.Queue(1) + q.put_nowait(0) + self.io_loop.call_later(0.01, q.get) + self.io_loop.call_later(0.02, q.get) + futures = [q.put(0), q.put(1)] + self.assertFalse(any(f.done() for f in futures)) + yield futures + + @gen_test + def test_put_timeout(self): + q = queues.Queue(1) + q.put_nowait(0) # Now it's full. + put_timeout = q.put(1, timeout=timedelta(seconds=0.01)) + put = q.put(2) + with self.assertRaises(TimeoutError): + yield put_timeout + + self.assertEqual(0, q.get_nowait()) + # 1 was never put in the queue. + self.assertEqual(2, (yield q.get())) + + # Final get() unblocked this putter. + yield put + + @gen_test + def test_put_timeout_preempted(self): + q = queues.Queue(1) + q.put_nowait(0) + put = q.put(1, timeout=timedelta(seconds=0.01)) + q.get() + yield gen.sleep(0.02) + yield put # No TimeoutError. + + @gen_test + def test_put_clears_timed_out_putters(self): + q = queues.Queue(1) + # First putter succeeds, remainder block. + putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)] + put = q.put(10) + self.assertEqual(10, len(q._putters)) + yield gen.sleep(0.02) + self.assertEqual(10, len(q._putters)) + self.assertFalse(put.done()) # Final waiter is still active. + q.put(11) # put() clears the waiters. + self.assertEqual(2, len(q._putters)) + for putter in putters[1:]: + self.assertRaises(TimeoutError, putter.result) + + @gen_test + def test_put_clears_timed_out_getters(self): + q = queues.Queue() + getters = [q.get(timedelta(seconds=0.01)) for _ in range(10)] + get = q.get() + q.get() + self.assertEqual(12, len(q._getters)) + yield gen.sleep(0.02) + self.assertEqual(12, len(q._getters)) + self.assertFalse(get.done()) # Final waiters still active. + q.put(0) # put() clears the waiters. 
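The mirror-image case, a put() parked on a full queue until a get() makes room, can be sketched the same way:

    from tornado import gen, ioloop, queues

    @gen.coroutine
    def main():
        q = queues.Queue(maxsize=1)
        q.put_nowait('first')
        put = q.put('second')   # queue is full; this Future is pending
        assert not put.done()
        assert 'first' == (yield q.get())   # makes room for the putter
        yield put               # the parked put() has now completed
        assert 'second' == (yield q.get())

    ioloop.IOLoop.current().run_sync(main)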
+ self.assertEqual(1, len(q._getters)) + self.assertEqual(0, (yield get)) + for getter in getters: + self.assertRaises(TimeoutError, getter.result) + + @gen_test + def test_float_maxsize(self): + # Non-int maxsize must round down: http://bugs.python.org/issue21723 + q = queues.Queue(maxsize=1.3) + self.assertTrue(q.empty()) + self.assertFalse(q.full()) + q.put_nowait(0) + q.put_nowait(1) + self.assertFalse(q.empty()) + self.assertTrue(q.full()) + self.assertRaises(queues.QueueFull, q.put_nowait, 2) + self.assertEqual(0, q.get_nowait()) + self.assertFalse(q.empty()) + self.assertFalse(q.full()) + + yield q.put(2) + put = q.put(3) + self.assertFalse(put.done()) + self.assertEqual(1, (yield q.get())) + yield put + self.assertTrue(q.full()) + + +class QueueJoinTest(AsyncTestCase): + queue_class = queues.Queue + + def test_task_done_underflow(self): + q = self.queue_class() + self.assertRaises(ValueError, q.task_done) + + @gen_test + def test_task_done(self): + q = self.queue_class() + for i in range(100): + q.put_nowait(i) + + self.accumulator = 0 + + @gen.coroutine + def worker(): + while True: + item = yield q.get() + self.accumulator += item + q.task_done() + yield gen.sleep(random() * 0.01) + + # Two coroutines share work. + worker() + worker() + yield q.join() + self.assertEqual(sum(range(100)), self.accumulator) + + @gen_test + def test_task_done_delay(self): + # Verify it is task_done(), not get(), that unblocks join(). + q = self.queue_class() + q.put_nowait(0) + join = q.join() + self.assertFalse(join.done()) + yield q.get() + self.assertFalse(join.done()) + yield gen.moment + self.assertFalse(join.done()) + q.task_done() + self.assertTrue(join.done()) + + @gen_test + def test_join_empty_queue(self): + q = self.queue_class() + yield q.join() + yield q.join() + + @gen_test + def test_join_timeout(self): + q = self.queue_class() + q.put(0) + with self.assertRaises(TimeoutError): + yield q.join(timeout=timedelta(seconds=0.01)) + + +class PriorityQueueJoinTest(QueueJoinTest): + queue_class = queues.PriorityQueue + + @gen_test + def test_order(self): + q = self.queue_class(maxsize=2) + q.put_nowait((1, 'a')) + q.put_nowait((0, 'b')) + self.assertTrue(q.full()) + q.put((3, 'c')) + q.put((2, 'd')) + self.assertEqual((0, 'b'), q.get_nowait()) + self.assertEqual((1, 'a'), (yield q.get())) + self.assertEqual((2, 'd'), q.get_nowait()) + self.assertEqual((3, 'c'), (yield q.get())) + self.assertTrue(q.empty()) + + +class LifoQueueJoinTest(QueueJoinTest): + queue_class = queues.LifoQueue + + @gen_test + def test_order(self): + q = self.queue_class(maxsize=2) + q.put_nowait(1) + q.put_nowait(0) + self.assertTrue(q.full()) + q.put(3) + q.put(2) + self.assertEqual(3, q.get_nowait()) + self.assertEqual(2, (yield q.get())) + self.assertEqual(0, q.get_nowait()) + self.assertEqual(1, (yield q.get())) + self.assertTrue(q.empty()) + + +class ProducerConsumerTest(AsyncTestCase): + @gen_test + def test_producer_consumer(self): + q = queues.Queue(maxsize=3) + history = [] + + # We don't yield between get() and task_done(), so get() must wait for + # the next tick. Otherwise we'd immediately call task_done and unblock + # join() before q.put() resumes, and we'd only process the first four + # items. 
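The PriorityQueue and LifoQueue subclasses tested above differ from Queue only in retrieval order. A standalone sketch of the priority behavior; the tuples are illustrative:

    from tornado import queues

    q = queues.PriorityQueue()
    for entry in [(2, 'medium'), (0, 'urgent'), (1, 'high')]:
        q.put_nowait(entry)

    # Entries come back lowest priority value first.
    assert q.get_nowait() == (0, 'urgent')
    assert q.get_nowait() == (1, 'high')
    assert q.get_nowait() == (2, 'medium')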
+ @gen.coroutine + def consumer(): + while True: + history.append((yield q.get())) + q.task_done() + + @gen.coroutine + def producer(): + for item in range(10): + yield q.put(item) + + consumer() + yield producer() + yield q.join() + self.assertEqual(list(range(10)), history) + + +if __name__ == '__main__': + unittest.main() diff --git a/tornado/test/runtests.py b/tornado/test/runtests.py index acbb5695e2a11bdd90985c2acae82828f5104a1e..ad9b0b8357be9271158502ba42d1dcc464e2763b 100644 --- a/tornado/test/runtests.py +++ b/tornado/test/runtests.py @@ -8,6 +8,7 @@ import operator import textwrap import sys from tornado.httpclient import AsyncHTTPClient +from tornado.httpserver import HTTPServer from tornado.ioloop import IOLoop from tornado.netutil import Resolver from tornado.options import define, options, add_parse_callback @@ -35,13 +36,16 @@ TEST_MODULES = [ 'tornado.test.ioloop_test', 'tornado.test.iostream_test', 'tornado.test.locale_test', + 'tornado.test.locks_test', 'tornado.test.netutil_test', 'tornado.test.log_test', 'tornado.test.options_test', 'tornado.test.process_test', + 'tornado.test.queues_test', 'tornado.test.simple_httpclient_test', 'tornado.test.stack_context_test', 'tornado.test.tcpclient_test', + 'tornado.test.tcpserver_test', 'tornado.test.template_test', 'tornado.test.testing_test', 'tornado.test.twisted_test', @@ -121,6 +125,8 @@ def main(): define('httpclient', type=str, default=None, callback=lambda s: AsyncHTTPClient.configure( s, defaults=dict(allow_ipv6=False))) + define('httpserver', type=str, default=None, + callback=HTTPServer.configure) define('ioloop', type=str, default=None) define('ioloop_time_monotonic', default=False) define('resolver', type=str, default=None, diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py index bb870db3b0308d109687ff78241b0c3edb7b67bb..c0de22b7cf08cdc186ad0c0b57c99e7186f541cf 100644 --- a/tornado/test/simple_httpclient_test.py +++ b/tornado/test/simple_httpclient_test.py @@ -8,11 +8,12 @@ import logging import os import re import socket +import ssl import sys from tornado import gen from tornado.httpclient import AsyncHTTPClient -from tornado.httputil import HTTPHeaders +from tornado.httputil import HTTPHeaders, ResponseStartLine from tornado.ioloop import IOLoop from tornado.log import gen_log from tornado.netutil import Resolver, bind_sockets @@ -20,7 +21,7 @@ from tornado.simple_httpclient import SimpleAsyncHTTPClient, _default_ca_certs from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler from tornado.test import httpclient_test from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog -from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port +from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port, unittest from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body @@ -97,15 +98,18 @@ class HostEchoHandler(RequestHandler): class NoContentLengthHandler(RequestHandler): - @gen.coroutine + @asynchronous def get(self): - # Emulate the old HTTP/1.0 behavior of returning a body with no - # content-length. Tornado handles content-length at the framework - # level so we have to go around it. - stream = self.request.connection.stream - yield stream.write(b"HTTP/1.0 200 OK\r\n\r\n" - b"hello") - stream.close() + if self.request.version.startswith('HTTP/1'): + # Emulate the old HTTP/1.0 behavior of returning a body with no + # content-length. 
Tornado handles content-length at the framework + # level so we have to go around it. + stream = self.request.connection.detach() + stream.write(b"HTTP/1.0 200 OK\r\n\r\n" + b"hello") + stream.close() + else: + self.finish('HTTP/1 required') class EchoPostHandler(RequestHandler): @@ -191,9 +195,6 @@ class SimpleHTTPClientTestMixin(object): response = self.wait() response.rethrow() - def test_default_certificates_exist(self): - open(_default_ca_certs()).close() - def test_gzip(self): # All the tests in this file should be using gzip, but this test # ensures that it is in fact getting compressed. @@ -359,7 +360,10 @@ class SimpleHTTPClientTestMixin(object): def test_no_content_length(self): response = self.fetch("/no_content_length") - self.assertEquals(b"hello", response.body) + if response.body == b"HTTP/1 required": + self.skipTest("requires HTTP/1.x") + else: + self.assertEquals(b"hello", response.body) def sync_body_producer(self, write): write(b'1234') @@ -432,6 +436,33 @@ class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase): defaults=dict(validate_cert=False), **kwargs) + def test_ssl_options(self): + resp = self.fetch("/hello", ssl_options={}) + self.assertEqual(resp.body, b"Hello world!") + + @unittest.skipIf(not hasattr(ssl, 'SSLContext'), + 'ssl.SSLContext not present') + def test_ssl_context(self): + resp = self.fetch("/hello", + ssl_options=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) + self.assertEqual(resp.body, b"Hello world!") + + def test_ssl_options_handshake_fail(self): + with ExpectLog(gen_log, "SSL Error|Uncaught exception", + required=False): + resp = self.fetch( + "/hello", ssl_options=dict(cert_reqs=ssl.CERT_REQUIRED)) + self.assertRaises(ssl.SSLError, resp.rethrow) + + @unittest.skipIf(not hasattr(ssl, 'SSLContext'), + 'ssl.SSLContext not present') + def test_ssl_context_handshake_fail(self): + with ExpectLog(gen_log, "SSL Error|Uncaught exception"): + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_REQUIRED + resp = self.fetch("/hello", ssl_options=ctx) + self.assertRaises(ssl.SSLError, resp.rethrow) + class CreateAsyncHTTPClientTestCase(AsyncTestCase): def setUp(self): @@ -467,6 +498,12 @@ class CreateAsyncHTTPClientTestCase(AsyncTestCase): class HTTP100ContinueTestCase(AsyncHTTPTestCase): def respond_100(self, request): + self.http1 = request.version.startswith('HTTP/1.') + if not self.http1: + request.connection.write_headers(ResponseStartLine('', 200, 'OK'), + HTTPHeaders()) + request.connection.finish() + return self.request = request self.request.connection.stream.write( b"HTTP/1.1 100 CONTINUE\r\n\r\n", @@ -483,11 +520,20 @@ class HTTP100ContinueTestCase(AsyncHTTPTestCase): def test_100_continue(self): res = self.fetch('/') + if not self.http1: + self.skipTest("requires HTTP/1.x") self.assertEqual(res.body, b'A') class HTTP204NoContentTestCase(AsyncHTTPTestCase): def respond_204(self, request): + self.http1 = request.version.startswith('HTTP/1.') + if not self.http1: + # Close the request cleanly in HTTP/2; it will be skipped anyway. + request.connection.write_headers(ResponseStartLine('', 200, 'OK'), + HTTPHeaders()) + request.connection.finish() + return # A 204 response never has a body, even if doesn't have a content-length # (which would otherwise mean read-until-close). 
Tornado always # sends a content-length, so we simulate here a server that sends @@ -495,14 +541,18 @@ class HTTP204NoContentTestCase(AsyncHTTPTestCase): # # Tests of a 204 response with a Content-Length header are included # in SimpleHTTPClientTestMixin. - request.connection.stream.write( + stream = request.connection.detach() + stream.write( b"HTTP/1.1 204 No content\r\n\r\n") + stream.close() def get_app(self): return self.respond_204 def test_204_no_content(self): resp = self.fetch('/') + if not self.http1: + self.skipTest("requires HTTP/1.x") self.assertEqual(resp.code, 204) self.assertEqual(resp.body, b'') @@ -581,3 +631,49 @@ class MaxHeaderSizeTest(AsyncHTTPTestCase): with ExpectLog(gen_log, "Unsatisfiable read"): response = self.fetch('/large') self.assertEqual(response.code, 599) + + +class MaxBodySizeTest(AsyncHTTPTestCase): + def get_app(self): + class SmallBody(RequestHandler): + def get(self): + self.write("a"*1024*64) + + class LargeBody(RequestHandler): + def get(self): + self.write("a"*1024*100) + + return Application([('/small', SmallBody), + ('/large', LargeBody)]) + + def get_http_client(self): + return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_body_size=1024*64) + + def test_small_body(self): + response = self.fetch('/small') + response.rethrow() + self.assertEqual(response.body, b'a'*1024*64) + + def test_large_body(self): + with ExpectLog(gen_log, "Malformed HTTP message from None: Content-Length too long"): + response = self.fetch('/large') + self.assertEqual(response.code, 599) + + +class MaxBufferSizeTest(AsyncHTTPTestCase): + def get_app(self): + + class LargeBody(RequestHandler): + def get(self): + self.write("a"*1024*100) + + return Application([('/large', LargeBody)]) + + def get_http_client(self): + # 100KB body with 64KB buffer + return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_body_size=1024*100, max_buffer_size=1024*64) + + def test_large_body(self): + response = self.fetch('/large') + response.rethrow() + self.assertEqual(response.body, b'a'*1024*100) diff --git a/tornado/test/static_foo.txt b/tornado/test/static_foo.txt new file mode 100644 index 0000000000000000000000000000000000000000..bdb44f39184e5d5f85a73eb2405e7c307da03ec3 --- /dev/null +++ b/tornado/test/static_foo.txt @@ -0,0 +1,2 @@ +This file should not be served by StaticFileHandler even though +its name starts with "static". diff --git a/tornado/test/tcpserver_test.py b/tornado/test/tcpserver_test.py new file mode 100644 index 0000000000000000000000000000000000000000..84c950769e4a9a31ec91d287355643759d711265 --- /dev/null +++ b/tornado/test/tcpserver_test.py @@ -0,0 +1,38 @@ +import socket + +from tornado import gen +from tornado.iostream import IOStream +from tornado.log import app_log +from tornado.stack_context import NullContext +from tornado.tcpserver import TCPServer +from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test + + +class TCPServerTest(AsyncTestCase): + @gen_test + def test_handle_stream_coroutine_logging(self): + # handle_stream may be a coroutine and any exception in its + # Future will be logged. 
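The max_body_size option exercised by MaxBodySizeTest above caps how large a response body SimpleAsyncHTTPClient will accept. A sketch of typical use; the URL and the server behind it are assumptions for illustration:

    from tornado import gen, ioloop
    from tornado.httpclient import HTTPError
    from tornado.simple_httpclient import SimpleAsyncHTTPClient

    @gen.coroutine
    def main():
        # force_instance sidesteps AsyncHTTPClient's shared-instance cache.
        client = SimpleAsyncHTTPClient(force_instance=True,
                                       max_body_size=64 * 1024)
        try:
            yield client.fetch('http://localhost:8888/large')  # assumed URL
        except HTTPError as e:
            print(e)   # code 599 once the body exceeds max_body_size

    ioloop.IOLoop.current().run_sync(main)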
+ class TestServer(TCPServer): + @gen.coroutine + def handle_stream(self, stream, address): + yield gen.moment + stream.close() + 1/0 + + server = client = None + try: + sock, port = bind_unused_port() + with NullContext(): + server = TestServer() + server.add_socket(sock) + client = IOStream(socket.socket()) + with ExpectLog(app_log, "Exception in callback"): + yield client.connect(('localhost', port)) + yield client.read_until_close() + yield gen.moment + finally: + if server is not None: + server.stop() + if client is not None: + client.close() diff --git a/tornado/test/twisted_test.py b/tornado/test/twisted_test.py index b31ae94cb90e5e4f01be1a005291bac36adb5f53..22410567a8955f0dff2601ad8614fbb12358d37b 100644 --- a/tornado/test/twisted_test.py +++ b/tornado/test/twisted_test.py @@ -75,6 +75,7 @@ skipIfNoTwisted = unittest.skipUnless(have_twisted, skipIfNoSingleDispatch = unittest.skipIf( gen.singledispatch is None, "singledispatch module not present") + def save_signal_handlers(): saved = {} for sig in [signal.SIGINT, signal.SIGTERM, signal.SIGCHLD]: @@ -452,6 +453,7 @@ class CompatibilityTests(unittest.TestCase): def twisted_coroutine_fetch(self, url, runner): body = [None] + @gen.coroutine def f(): # This is simpler than the non-coroutine version, but it cheats @@ -552,7 +554,7 @@ if have_twisted: 'test_changeUID', ], # Process tests appear to work on OSX 10.7, but not 10.6 - #'twisted.internet.test.test_process.PTYProcessTestsBuilder': [ + # 'twisted.internet.test.test_process.PTYProcessTestsBuilder': [ # 'test_systemCallUninterruptedByChildExit', # ], 'twisted.internet.test.test_tcp.TCPClientTestsBuilder': [ @@ -571,7 +573,7 @@ if have_twisted: 'twisted.internet.test.test_threads.ThreadTestsBuilder': [], 'twisted.internet.test.test_time.TimeTestsBuilder': [], # Extra third-party dependencies (pyOpenSSL) - #'twisted.internet.test.test_tls.SSLClientTestsMixin': [], + # 'twisted.internet.test.test_tls.SSLClientTestsMixin': [], 'twisted.internet.test.test_udp.UDPServerTestsBuilder': [], 'twisted.internet.test.test_unix.UNIXTestsBuilder': [ # Platform-specific. These tests would be skipped automatically @@ -591,6 +593,14 @@ if have_twisted: ], 'twisted.internet.test.test_unix.UNIXPortTestsBuilder': [], } + if sys.version_info >= (3,): + # In Twisted 15.2.0 on Python 3.4, the process tests will try to run + # but fail, due in part to interactions between Tornado's strict + # warnings-as-errors policy and Twisted's own warning handling + # (it was not obvious how to configure the warnings module to + # reconcile the two), and partly due to what looks like a packaging + # error (process_cli.py missing). For now, just skip it. + del twisted_tests['twisted.internet.test.test_process.ProcessTestsBuilder'] for test_name, blacklist in twisted_tests.items(): try: test_class = import_object(test_name) @@ -657,13 +667,13 @@ if have_twisted: correctly. In some tests another TornadoReactor is layered on top of the whole stack. """ - def initialize(self): + def initialize(self, **kwargs): # When configured to use LayeredTwistedIOLoop we can't easily # get the next-best IOLoop implementation, so use the lowest common # denominator. 
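The test above depends on handle_stream being allowed to be a coroutine, with any escaping exception logged rather than lost. A minimal echo server written in that style; the class name is illustrative:

    from tornado import gen
    from tornado.iostream import StreamClosedError
    from tornado.tcpserver import TCPServer

    class EchoServer(TCPServer):
        @gen.coroutine
        def handle_stream(self, stream, address):
            # An exception raised here lands in the coroutine's Future,
            # which the server logs.
            try:
                while True:
                    line = yield stream.read_until(b'\n')
                    yield stream.write(line)
            except StreamClosedError:
                pass

An instance would be started with EchoServer().listen(port) under a running IOLoop.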
self.real_io_loop = SelectIOLoop() reactor = TornadoReactor(io_loop=self.real_io_loop) - super(LayeredTwistedIOLoop, self).initialize(reactor=reactor) + super(LayeredTwistedIOLoop, self).initialize(reactor=reactor, **kwargs) self.add_callback(self.make_current) def close(self, all_fds=False): diff --git a/tornado/test/util.py b/tornado/test/util.py index 358809f216e247014b2401089c9744a5cde336cc..9dd9c0ce12d42b099b708fe1033c9edcfda70829 100644 --- a/tornado/test/util.py +++ b/tornado/test/util.py @@ -31,6 +31,7 @@ skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ, skipIfNoIPv6 = unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present') + def refusing_port(): """Returns a local port number that will refuse all connections. diff --git a/tornado/test/util_test.py b/tornado/test/util_test.py index 1cd78fe46f3b5672e52940e91240e32b929fd102..0936c89ad1792eda892e805a3bafdfbd8caa7914 100644 --- a/tornado/test/util_test.py +++ b/tornado/test/util_test.py @@ -3,8 +3,9 @@ from __future__ import absolute_import, division, print_function, with_statement import sys import datetime +import tornado.escape from tornado.escape import utf8 -from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer, timedelta_to_seconds +from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer, timedelta_to_seconds, import_object from tornado.test.util import unittest try: @@ -45,13 +46,15 @@ class TestConfigurable(Configurable): class TestConfig1(TestConfigurable): - def initialize(self, a=None): + def initialize(self, pos_arg=None, a=None): self.a = a + self.pos_arg = pos_arg class TestConfig2(TestConfigurable): - def initialize(self, b=None): + def initialize(self, pos_arg=None, b=None): self.b = b + self.pos_arg = pos_arg class ConfigurableTest(unittest.TestCase): @@ -101,9 +104,10 @@ class ConfigurableTest(unittest.TestCase): self.assertIsInstance(obj, TestConfig1) self.assertEqual(obj.a, 3) - obj = TestConfigurable(a=4) + obj = TestConfigurable(42, a=4) self.assertIsInstance(obj, TestConfig1) self.assertEqual(obj.a, 4) + self.assertEqual(obj.pos_arg, 42) self.checkSubclasses() # args bound in configure don't apply when using the subclass directly @@ -116,9 +120,10 @@ class ConfigurableTest(unittest.TestCase): self.assertIsInstance(obj, TestConfig2) self.assertEqual(obj.b, 5) - obj = TestConfigurable(b=6) + obj = TestConfigurable(42, b=6) self.assertIsInstance(obj, TestConfig2) self.assertEqual(obj.b, 6) + self.assertEqual(obj.pos_arg, 42) self.checkSubclasses() # args bound in configure don't apply when using the subclass directly @@ -177,3 +182,20 @@ class TimedeltaToSecondsTest(unittest.TestCase): def test_timedelta_to_seconds(self): time_delta = datetime.timedelta(hours=1) self.assertEqual(timedelta_to_seconds(time_delta), 3600.0) + + +class ImportObjectTest(unittest.TestCase): + def test_import_member(self): + self.assertIs(import_object('tornado.escape.utf8'), utf8) + + def test_import_member_unicode(self): + self.assertIs(import_object(u('tornado.escape.utf8')), utf8) + + def test_import_module(self): + self.assertIs(import_object('tornado.escape'), tornado.escape) + + def test_import_module_unicode(self): + # The internal implementation of __import__ differs depending on + # whether the thing being imported is a module or not. + # This variant requires a byte string in python 2. 
+ self.assertIs(import_object(u('tornado.escape')), tornado.escape) diff --git a/tornado/test/web_test.py b/tornado/test/web_test.py index 77ad388812de95585673b261e6b608b8f41a25d3..9374c4824b378050fd3b2c7d6c34c4b8e844004d 100644 --- a/tornado/test/web_test.py +++ b/tornado/test/web_test.py @@ -11,7 +11,7 @@ from tornado.template import DictLoader from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test from tornado.test.util import unittest from tornado.util import u, ObjectDict, unicode_type, timedelta_to_seconds -from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler +from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler, get_signature_key_version import binascii import contextlib @@ -71,10 +71,14 @@ class HelloHandler(RequestHandler): class CookieTestRequestHandler(RequestHandler): # stub out enough methods to make the secure_cookie functions work - def __init__(self): + def __init__(self, cookie_secret='0123456789', key_version=None): # don't call super.__init__ self._cookies = {} - self.application = ObjectDict(settings=dict(cookie_secret='0123456789')) + if key_version is None: + self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret)) + else: + self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret, + key_version=key_version)) def get_cookie(self, name): return self._cookies.get(name) @@ -128,6 +132,51 @@ class SecureCookieV1Test(unittest.TestCase): self.assertEqual(handler.get_secure_cookie('foo', min_version=1), b'\xe9') +# See SignedValueTest below for more. 
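The key-versioning scheme these cookie tests cover is driven by two application settings: cookie_secret becomes a dict mapping an integer version to a secret, and key_version picks the signing key, so old cookies stay valid until their key is removed. A sketch of opting in; the secrets are placeholders:

    import tornado.web

    application = tornado.web.Application(
        [],   # url patterns omitted
        cookie_secret={0: 'old-secret', 1: 'new-secret'},
        key_version=1,   # sign new cookies with key 1; key 0 still validates
    )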
+class SecureCookieV2Test(unittest.TestCase): + KEY_VERSIONS = { + 0: 'ajklasdf0ojaisdf', + 1: 'aslkjasaolwkjsdf' + } + + def test_round_trip(self): + handler = CookieTestRequestHandler() + handler.set_secure_cookie('foo', b'bar', version=2) + self.assertEqual(handler.get_secure_cookie('foo', min_version=2), b'bar') + + def test_key_version_roundtrip(self): + handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS, + key_version=0) + handler.set_secure_cookie('foo', b'bar') + self.assertEqual(handler.get_secure_cookie('foo'), b'bar') + + def test_key_version_roundtrip_differing_version(self): + handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS, + key_version=1) + handler.set_secure_cookie('foo', b'bar') + self.assertEqual(handler.get_secure_cookie('foo'), b'bar') + + def test_key_version_increment_version(self): + handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS, + key_version=0) + handler.set_secure_cookie('foo', b'bar') + new_handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS, + key_version=1) + new_handler._cookies = handler._cookies + self.assertEqual(new_handler.get_secure_cookie('foo'), b'bar') + + def test_key_version_invalidate_version(self): + handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS, + key_version=0) + handler.set_secure_cookie('foo', b'bar') + new_key_versions = self.KEY_VERSIONS.copy() + new_key_versions.pop(0) + new_handler = CookieTestRequestHandler(cookie_secret=new_key_versions, + key_version=1) + new_handler._cookies = handler._cookies + self.assertEqual(new_handler.get_secure_cookie('foo'), None) + + class CookieTest(WebTestCase): def get_handlers(self): class SetCookieHandler(RequestHandler): @@ -171,6 +220,13 @@ class CookieTest(WebTestCase): def get(self): self.set_cookie("foo", "bar", expires_days=10) + class SetCookieFalsyFlags(RequestHandler): + def get(self): + self.set_cookie("a", "1", secure=True) + self.set_cookie("b", "1", secure=False) + self.set_cookie("c", "1", httponly=True) + self.set_cookie("d", "1", httponly=False) + return [("/set", SetCookieHandler), ("/get", GetCookieHandler), ("/set_domain", SetCookieDomainHandler), @@ -178,6 +234,7 @@ class CookieTest(WebTestCase): ("/set_overwrite", SetCookieOverwriteHandler), ("/set_max_age", SetCookieMaxAgeHandler), ("/set_expires_days", SetCookieExpiresDaysHandler), + ("/set_falsy_flags", SetCookieFalsyFlags) ] def test_set_cookie(self): @@ -237,7 +294,7 @@ class CookieTest(WebTestCase): headers = response.headers.get_list("Set-Cookie") self.assertEqual(sorted(headers), ["foo=bar; Max-Age=10; Path=/"]) - + def test_set_cookie_expires_days(self): response = self.fetch("/set_expires_days") header = response.headers.get("Set-Cookie") @@ -248,7 +305,17 @@ class CookieTest(WebTestCase): header_expires = datetime.datetime( *email.utils.parsedate(match.groupdict()["expires"])[:6]) self.assertTrue(abs(timedelta_to_seconds(expires - header_expires)) < 10) - + + def test_set_cookie_false_flags(self): + response = self.fetch("/set_falsy_flags") + headers = sorted(response.headers.get_list("Set-Cookie")) + # The secure and httponly headers are capitalized in py35 and + # lowercase in older versions. 
+ self.assertEqual(headers[0].lower(), 'a=1; path=/; secure') + self.assertEqual(headers[1].lower(), 'b=1; path=/') + self.assertEqual(headers[2].lower(), 'c=1; httponly; path=/') + self.assertEqual(headers[3].lower(), 'd=1; path=/') + class AuthRedirectRequestHandler(RequestHandler): def initialize(self, login_url): @@ -379,6 +446,12 @@ class RequestEncodingTest(WebTestCase): path_args=["a/b", "c/d"], args={})) + def test_error(self): + # Percent signs (encoded as %25) should not mess up printf-style + # messages in logs + with ExpectLog(gen_log, ".*Invalid unicode"): + self.fetch("/group/?arg=%25%e9") + class TypeCheckHandler(RequestHandler): def prepare(self): @@ -579,6 +652,7 @@ class WSGISafeWebTest(WebTestCase): url("/redirect", RedirectHandler), url("/web_redirect_permanent", WebRedirectHandler, {"url": "/web_redirect_newpath"}), url("/web_redirect", WebRedirectHandler, {"url": "/web_redirect_newpath", "permanent": False}), + url("//web_redirect_double_slash", WebRedirectHandler, {"url": '/web_redirect_newpath'}), url("/header_injection", HeaderInjectionHandler), url("/get_argument", GetArgumentHandler), url("/get_arguments", GetArgumentsHandler), @@ -712,6 +786,11 @@ js_embed() self.assertEqual(response.code, 302) self.assertEqual(response.headers['Location'], '/web_redirect_newpath') + def test_web_redirect_double_slash(self): + response = self.fetch("//web_redirect_double_slash", follow_redirects=False) + self.assertEqual(response.code, 301) + self.assertEqual(response.headers['Location'], '/web_redirect_newpath') + def test_header_injection(self): response = self.fetch("/header_injection") self.assertEqual(response.body, b"ok") @@ -1102,6 +1181,15 @@ class StaticFileTest(WebTestCase): response = self.get_and_head('/static/blarg') self.assertEqual(response.code, 404) + def test_path_traversal_protection(self): + with ExpectLog(gen_log, ".*not in root static directory"): + response = self.get_and_head('/static/../static_foo.txt') + # Attempted path traversal should result in 403, not 200 + # (which means the check failed and the file was served) + # or 404 (which means that the file didn't exist and + # is probably a packaging error). + self.assertEqual(response.code, 403) + @wsgi_safe class StaticDefaultFilenameTest(WebTestCase): @@ -1517,6 +1605,22 @@ class ExceptionHandlerTest(SimpleHandlerTestCase): self.assertEqual(response.code, 403) +@wsgi_safe +class BuggyLoggingTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + 1/0 + + def log_exception(self, typ, value, tb): + 1/0 + + def test_buggy_log_exception(self): + # Something gets logged even though the application's + # logger is broken. + with ExpectLog(app_log, '.*'): + self.fetch('/') + + @wsgi_safe class UIMethodUIModuleTest(SimpleHandlerTestCase): """Test that UI methods and modules are created correctly and @@ -1533,6 +1637,7 @@ class UIMethodUIModuleTest(SimpleHandlerTestCase): def my_ui_method(handler, x): return "In my_ui_method(%s) with handler value %s." % ( x, handler.value()) + class MyModule(UIModule): def render(self, x): return "In MyModule(%s) with handler value %s." % ( @@ -1988,8 +2093,10 @@ class StreamingRequestFlowControlTest(WebTestCase): @gen.coroutine def prepare(self): - with self.in_method('prepare'): - yield gen.Task(IOLoop.current().add_callback) + # Note that asynchronous prepare() does not block data_received, + # so we don't use in_method here. 
+ self.methods.append('prepare') + yield gen.Task(IOLoop.current().add_callback) @gen.coroutine def data_received(self, data): @@ -2051,9 +2158,10 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase): # When the content-length is too high, the connection is simply # closed without completing the response. An error is logged on # the server. - with ExpectLog(app_log, "Uncaught exception"): + with ExpectLog(app_log, "(Uncaught exception|Exception in callback)"): with ExpectLog(gen_log, - "Cannot send error response after headers written"): + "(Cannot send error response after headers written" + "|Failed to flush partial response)"): response = self.fetch("/high") self.assertEqual(response.code, 599) self.assertEqual(str(self.server_error), @@ -2063,9 +2171,10 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase): # When the content-length is too low, the connection is closed # without writing the last chunk, so the client never sees the request # complete (which would be a framing error). - with ExpectLog(app_log, "Uncaught exception"): + with ExpectLog(app_log, "(Uncaught exception|Exception in callback)"): with ExpectLog(gen_log, - "Cannot send error response after headers written"): + "(Cannot send error response after headers written" + "|Failed to flush partial response)"): response = self.fetch("/low") self.assertEqual(response.code, 599) self.assertEqual(str(self.server_error), @@ -2075,21 +2184,28 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase): class ClientCloseTest(SimpleHandlerTestCase): class Handler(RequestHandler): def get(self): - # Simulate a connection closed by the client during - # request processing. The client will see an error, but the - # server should respond gracefully (without logging errors - # because we were unable to write out as many bytes as - # Content-Length said we would) - self.request.connection.stream.close() - self.write('hello') + if self.request.version.startswith('HTTP/1'): + # Simulate a connection closed by the client during + # request processing. The client will see an error, but the + # server should respond gracefully (without logging errors + # because we were unable to write out as many bytes as + # Content-Length said we would) + self.request.connection.stream.close() + self.write('hello') + else: + # TODO: add a HTTP2-compatible version of this test. + self.write('requires HTTP/1.x') def test_client_close(self): response = self.fetch('/') + if response.body == b'requires HTTP/1.x': + self.skipTest('requires HTTP/1.x') self.assertEqual(response.code, 599) class SignedValueTest(unittest.TestCase): SECRET = "It's a secret to everybody" + SECRET_DICT = {0: "asdfbasdf", 1: "12312312", 2: "2342342"} def past(self): return self.present() - 86400 * 32 @@ -2151,6 +2267,7 @@ class SignedValueTest(unittest.TestCase): def test_payload_tampering(self): # These cookies are variants of the one in test_known_values. 
sig = "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152" + def validate(prefix): return (b'value' == decode_signed_value(SignedValueTest.SECRET, "key", @@ -2165,6 +2282,7 @@ class SignedValueTest(unittest.TestCase): def test_signature_tampering(self): prefix = "2|1:0|10:1300000000|3:key|8:dmFsdWU=|" + def validate(sig): return (b'value' == decode_signed_value(SignedValueTest.SECRET, "key", @@ -2194,6 +2312,43 @@ class SignedValueTest(unittest.TestCase): clock=self.present) self.assertEqual(value, decoded) + def test_key_versioning_read_write_default_key(self): + value = b"\xe9" + signed = create_signed_value(SignedValueTest.SECRET_DICT, + "key", value, clock=self.present, + key_version=0) + decoded = decode_signed_value(SignedValueTest.SECRET_DICT, + "key", signed, clock=self.present) + self.assertEqual(value, decoded) + + def test_key_versioning_read_write_non_default_key(self): + value = b"\xe9" + signed = create_signed_value(SignedValueTest.SECRET_DICT, + "key", value, clock=self.present, + key_version=1) + decoded = decode_signed_value(SignedValueTest.SECRET_DICT, + "key", signed, clock=self.present) + self.assertEqual(value, decoded) + + def test_key_versioning_invalid_key(self): + value = b"\xe9" + signed = create_signed_value(SignedValueTest.SECRET_DICT, + "key", value, clock=self.present, + key_version=0) + newkeys = SignedValueTest.SECRET_DICT.copy() + newkeys.pop(0) + decoded = decode_signed_value(newkeys, + "key", signed, clock=self.present) + self.assertEqual(None, decoded) + + def test_key_version_retrieval(self): + value = b"\xe9" + signed = create_signed_value(SignedValueTest.SECRET_DICT, + "key", value, clock=self.present, + key_version=1) + key_version = get_signature_key_version(signed) + self.assertEqual(1, key_version) + @wsgi_safe class XSRFTest(SimpleHandlerTestCase): @@ -2296,7 +2451,7 @@ class XSRFTest(SimpleHandlerTestCase): token2 = self.get_token() # Each token can be used to authenticate its own request. 
for token in (self.xsrf_token, token2): - response = self.fetch( + response = self.fetch( "/", method="POST", body=urllib_parse.urlencode(dict(_xsrf=token)), headers=self.cookie_headers(token)) @@ -2372,6 +2527,7 @@ class FinishExceptionTest(SimpleHandlerTestCase): self.assertEqual(b'authentication required', response.body) +@wsgi_safe class DecoratorTest(WebTestCase): def get_handlers(self): class RemoveSlashHandler(RequestHandler): @@ -2405,3 +2561,85 @@ class DecoratorTest(WebTestCase): response = self.fetch("/addslash?foo=bar", follow_redirects=False) self.assertEqual(response.code, 301) self.assertEqual(response.headers['Location'], "/addslash/?foo=bar") + + +@wsgi_safe +class CacheTest(WebTestCase): + def get_handlers(self): + class EtagHandler(RequestHandler): + def get(self, computed_etag): + self.write(computed_etag) + + def compute_etag(self): + return self._write_buffer[0] + + return [ + ('/etag/(.*)', EtagHandler) + ] + + def test_wildcard_etag(self): + computed_etag = '"xyzzy"' + etags = '*' + self._test_etag(computed_etag, etags, 304) + + def test_strong_etag_match(self): + computed_etag = '"xyzzy"' + etags = '"xyzzy"' + self._test_etag(computed_etag, etags, 304) + + def test_multiple_strong_etag_match(self): + computed_etag = '"xyzzy1"' + etags = '"xyzzy1", "xyzzy2"' + self._test_etag(computed_etag, etags, 304) + + def test_strong_etag_not_match(self): + computed_etag = '"xyzzy"' + etags = '"xyzzy1"' + self._test_etag(computed_etag, etags, 200) + + def test_multiple_strong_etag_not_match(self): + computed_etag = '"xyzzy"' + etags = '"xyzzy1", "xyzzy2"' + self._test_etag(computed_etag, etags, 200) + + def test_weak_etag_match(self): + computed_etag = '"xyzzy1"' + etags = 'W/"xyzzy1"' + self._test_etag(computed_etag, etags, 304) + + def test_multiple_weak_etag_match(self): + computed_etag = '"xyzzy2"' + etags = 'W/"xyzzy1", W/"xyzzy2"' + self._test_etag(computed_etag, etags, 304) + + def test_weak_etag_not_match(self): + computed_etag = '"xyzzy2"' + etags = 'W/"xyzzy1"' + self._test_etag(computed_etag, etags, 200) + + def test_multiple_weak_etag_not_match(self): + computed_etag = '"xyzzy3"' + etags = 'W/"xyzzy1", W/"xyzzy2"' + self._test_etag(computed_etag, etags, 200) + + def _test_etag(self, computed_etag, etags, status_code): + response = self.fetch( + '/etag/' + computed_etag, + headers={'If-None-Match': etags} + ) + self.assertEqual(response.code, status_code) + + +@wsgi_safe +class RequestSummaryTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + # remote_ip is optional, although it's set by + # both HTTPServer and WSGIAdapter. + # Clobber it to make sure it doesn't break logging. 
+ self.request.remote_ip = None + self.finish(self._request_summary()) + + def test_missing_remote_ip(self): + resp = self.fetch("/") + self.assertEqual(resp.body, b"GET / (None)") diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index 7e93d17141a886a9cb06a5f840a11f7e723b9dae..23a4324ce66193ef8ee67c9c16afe2fdfd51033c 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -12,7 +12,7 @@ from tornado.web import Application, RequestHandler from tornado.util import u try: - import tornado.websocket + import tornado.websocket # noqa from tornado.util import _websocket_mask_python except ImportError: # The unittest module presents misleading errors on ImportError @@ -53,7 +53,7 @@ class EchoHandler(TestWebSocketHandler): class ErrorInOnMessageHandler(TestWebSocketHandler): def on_message(self, message): - 1/0 + 1 / 0 class HeaderHandler(TestWebSocketHandler): @@ -106,6 +106,7 @@ class WebSocketBaseTestCase(AsyncHTTPTestCase): ws.close() yield self.close_future + class WebSocketTest(WebSocketBaseTestCase): def get_app(self): self.close_future = Future() @@ -223,7 +224,11 @@ class WebSocketTest(WebSocketBaseTestCase): self.assertEqual(ws.close_code, 1001) self.assertEqual(ws.close_reason, "goodbye") # The on_close callback is called no matter which side closed. - yield self.close_future + code, reason = yield self.close_future + # The client echoed the close code it received to the server, + # so the server's close code (returned via close_future) is + # the same. + self.assertEqual(code, 1001) @gen_test def test_client_close_reason(self): @@ -250,7 +255,7 @@ class WebSocketTest(WebSocketBaseTestCase): headers = {'Origin': 'http://127.0.0.1:%d' % port} ws = yield websocket_connect(HTTPRequest(url, headers=headers), - io_loop=self.io_loop) + io_loop=self.io_loop) ws.write_message('hello') response = yield ws.read_message() self.assertEqual(response, 'hello') @@ -264,7 +269,7 @@ class WebSocketTest(WebSocketBaseTestCase): headers = {'Origin': 'http://127.0.0.1:%d/something' % port} ws = yield websocket_connect(HTTPRequest(url, headers=headers), - io_loop=self.io_loop) + io_loop=self.io_loop) ws.write_message('hello') response = yield ws.read_message() self.assertEqual(response, 'hello') diff --git a/tornado/testing.py b/tornado/testing.py index 3d3bcf72b974d9cc0f8231982da00fef19766894..93f0dbe14196569f467b015789140e88eae4a8fc 100644 --- a/tornado/testing.py +++ b/tornado/testing.py @@ -417,10 +417,8 @@ class AsyncHTTPSTestCase(AsyncHTTPTestCase): Interface is generally the same as `AsyncHTTPTestCase`. """ def get_http_client(self): - # Some versions of libcurl have deadlock bugs with ssl, - # so always run these tests with SimpleAsyncHTTPClient. - return SimpleAsyncHTTPClient(io_loop=self.io_loop, force_instance=True, - defaults=dict(validate_cert=False)) + return AsyncHTTPClient(io_loop=self.io_loop, force_instance=True, + defaults=dict(validate_cert=False)) def get_httpserver_options(self): return dict(ssl_options=self.get_ssl_options()) diff --git a/tornado/util.py b/tornado/util.py index 34c4b072c49bb83cc80e38e00dcaca16174ab12d..606ced197350ab6a2c6eb32e70d9e933cf6cb3de 100644 --- a/tornado/util.py +++ b/tornado/util.py @@ -78,6 +78,25 @@ class GzipDecompressor(object): return self.decompressobj.flush() +# Fake unicode literal support: Python 3.2 doesn't have the u'' marker for +# literal strings, and alternative solutions like "from __future__ import +# unicode_literals" have other problems (see PEP 414). 
u() can be applied +# to ascii strings that include \u escapes (but they must not contain +# literal non-ascii characters). +if not isinstance(b'', type('')): + def u(s): + return s + unicode_type = str + basestring_type = str +else: + def u(s): + return s.decode('unicode_escape') + # These names don't exist in py3, so use noqa comments to disable + # warnings in flake8. + unicode_type = unicode # noqa + basestring_type = basestring # noqa + + def import_object(name): """Imports an object by name. @@ -96,6 +115,9 @@ def import_object(name): ... ImportError: No module named missing_module """ + if isinstance(name, unicode_type) and str is not unicode_type: + # On python 2 a byte string is required. + name = name.encode('utf-8') if name.count('.') == 0: return __import__(name, None, None) @@ -107,22 +129,6 @@ def import_object(name): raise ImportError("No module named %s" % parts[-1]) -# Fake unicode literal support: Python 3.2 doesn't have the u'' marker for -# literal strings, and alternative solutions like "from __future__ import -# unicode_literals" have other problems (see PEP 414). u() can be applied -# to ascii strings that include \u escapes (but they must not contain -# literal non-ascii characters). -if type('') is not type(b''): - def u(s): - return s - unicode_type = str - basestring_type = str -else: - def u(s): - return s.decode('unicode_escape') - unicode_type = unicode - basestring_type = basestring - # Deprecated alias that was used before we dropped py25 support. # Left here in case anyone outside Tornado is using it. bytes_type = bytes @@ -192,21 +198,21 @@ class Configurable(object): __impl_class = None __impl_kwargs = None - def __new__(cls, **kwargs): + def __new__(cls, *args, **kwargs): base = cls.configurable_base() - args = {} + init_kwargs = {} if cls is base: impl = cls.configured_class() if base.__impl_kwargs: - args.update(base.__impl_kwargs) + init_kwargs.update(base.__impl_kwargs) else: impl = cls - args.update(kwargs) + init_kwargs.update(kwargs) instance = super(Configurable, cls).__new__(impl) # initialize vs __init__ chosen for compatibility with AsyncHTTPClient # singleton magic. If we get rid of that we can switch to __init__ # here too. - instance.initialize(**args) + instance.initialize(*args, **init_kwargs) return instance @classmethod @@ -227,6 +233,9 @@ class Configurable(object): """Initialize a `Configurable` subclass instance. Configurable classes should use `initialize` instead of ``__init__``. + + .. versionchanged:: 4.2 + Now accepts positional arguments in addition to keyword arguments. """ @classmethod @@ -339,7 +348,7 @@ def _websocket_mask_python(mask, data): return unmasked.tostring() if (os.environ.get('TORNADO_NO_EXTENSION') or - os.environ.get('TORNADO_EXTENSION') == '0'): + os.environ.get('TORNADO_EXTENSION') == '0'): # These environment variables exist to make it easier to do performance # comparisons; they are not guaranteed to remain supported in the future. _websocket_mask = _websocket_mask_python diff --git a/tornado/web.py b/tornado/web.py index 52bfce3663f8d1c585e25cff67c3aafd5f9114ab..9847bb02e93215da78d7d0773b8e8630d57d4631 100644 --- a/tornado/web.py +++ b/tornado/web.py @@ -19,7 +19,9 @@ features that allow it to scale to large numbers of open connections, making it ideal for `long polling <http://en.wikipedia.org/wiki/Push_technology#Long_polling>`_. -Here is a simple "Hello, world" example app:: +Here is a simple "Hello, world" example app: + +.. 
testcode:: import tornado.ioloop import tornado.web @@ -33,7 +35,11 @@ Here is a simple "Hello, world" example app:: (r"/", MainHandler), ]) application.listen(8888) - tornado.ioloop.IOLoop.instance().start() + tornado.ioloop.IOLoop.current().start() + +.. testoutput:: + :hide: + See the :doc:`guide` for additional information. @@ -50,7 +56,8 @@ request. """ -from __future__ import absolute_import, division, print_function, with_statement +from __future__ import (absolute_import, division, + print_function, with_statement) import base64 @@ -84,7 +91,8 @@ from tornado.log import access_log, app_log, gen_log from tornado import stack_context from tornado import template from tornado.escape import utf8, _unicode -from tornado.util import import_object, ObjectDict, raise_exc_info, unicode_type, _websocket_mask +from tornado.util import (import_object, ObjectDict, raise_exc_info, + unicode_type, _websocket_mask) from tornado.httputil import split_host_and_port @@ -131,18 +139,17 @@ May be overridden by passing a ``version`` keyword argument. DEFAULT_SIGNED_VALUE_MIN_VERSION = 1 """The oldest signed value accepted by `.RequestHandler.get_secure_cookie`. -May be overrided by passing a ``min_version`` keyword argument. +May be overridden by passing a ``min_version`` keyword argument. .. versionadded:: 3.2.1 """ class RequestHandler(object): - """Subclass this class and define `get()` or `post()` to make a handler. + """Base class for HTTP request handlers. - If you want to support more methods than the standard GET/HEAD/POST, you - should override the class variable ``SUPPORTED_METHODS`` in your - `RequestHandler` subclass. + Subclasses must define at least one of the methods defined in the + "Entry points" section below. """ SUPPORTED_METHODS = ("GET", "HEAD", "POST", "DELETE", "PATCH", "PUT", "OPTIONS") @@ -384,6 +391,12 @@ class RequestHandler(object): The returned values are always unicode. """ + + # Make sure `get_arguments` isn't accidentally being called with a + # positional argument that's assumed to be a default (like in + # `get_argument`.) + assert isinstance(strip, bool) + return self._get_arguments(name, self.request.arguments, strip) def get_body_argument(self, name, default=_ARG_DEFAULT, strip=True): @@ -400,7 +413,8 @@ class RequestHandler(object): .. versionadded:: 3.2 """ - return self._get_argument(name, default, self.request.body_arguments, strip) + return self._get_argument(name, default, self.request.body_arguments, + strip) def get_body_arguments(self, name, strip=True): """Returns a list of the body arguments with the given name. @@ -427,7 +441,8 @@ class RequestHandler(object): .. versionadded:: 3.2 """ - return self._get_argument(name, default, self.request.query_arguments, strip) + return self._get_argument(name, default, + self.request.query_arguments, strip) def get_query_arguments(self, name, strip=True): """Returns a list of the query arguments with the given name. 
@@ -482,7 +497,8 @@ class RequestHandler(object):
 
     @property
     def cookies(self):
-        """An alias for `self.request.cookies <.httputil.HTTPServerRequest.cookies>`."""
+        """An alias for
+        `self.request.cookies <.httputil.HTTPServerRequest.cookies>`."""
         return self.request.cookies
 
     def get_cookie(self, name, default=None):
@@ -524,6 +540,12 @@ class RequestHandler(object):
         for k, v in kwargs.items():
             if k == 'max_age':
                 k = 'max-age'
+
+            # skip falsy values for httponly and secure flags because
+            # SimpleCookie sets them regardless
+            if k in ['httponly', 'secure'] and not v:
+                continue
+
             morsel[k] = v
 
     def clear_cookie(self, name, path="/", domain=None):
@@ -590,8 +612,15 @@
         and made it the default.
         """
         self.require_setting("cookie_secret", "secure cookies")
-        return create_signed_value(self.application.settings["cookie_secret"],
-                                   name, value, version=version)
+        secret = self.application.settings["cookie_secret"]
+        key_version = None
+        if isinstance(secret, dict):
+            if self.application.settings.get("key_version") is None:
+                raise Exception("key_version setting must be used for secret_key dicts")
+            key_version = self.application.settings["key_version"]
+
+        return create_signed_value(secret, name, value, version=version,
+                                   key_version=key_version)
 
     def get_secure_cookie(self, name, value=None, max_age_days=31,
                           min_version=None):
@@ -612,6 +641,17 @@ class RequestHandler(object):
             name, value, max_age_days=max_age_days,
             min_version=min_version)
 
+    def get_secure_cookie_key_version(self, name, value=None):
+        """Returns the signing key version of the secure cookie.
+
+        The version is returned as an int.
+        """
+        self.require_setting("cookie_secret", "secure cookies")
+        if value is None:
+            value = self.get_cookie(name)
+        return get_signature_key_version(value)
+
+
     def redirect(self, url, permanent=False, status=None):
         """Sends a redirect to the given (optionally relative) URL.
 
@@ -627,8 +667,7 @@ class RequestHandler(object):
         else:
             assert isinstance(status, int) and 300 <= status <= 399
         self.set_status(status)
-        self.set_header("Location", urlparse.urljoin(utf8(self.request.uri),
-                                                     utf8(url)))
+        self.set_header("Location", utf8(url))
         self.finish()
 
     def write(self, chunk):
@@ -648,11 +687,12 @@ class RequestHandler(object):
         https://github.com/facebook/tornado/issues/1009
         """
         if self._finished:
-            raise RuntimeError("Cannot write() after finish(). May be caused "
-                               "by using async operations without the "
-                               "@asynchronous decorator.")
+            raise RuntimeError("Cannot write() after finish()")
         if not isinstance(chunk, (bytes, unicode_type, dict)):
-            raise TypeError("write() only accepts bytes, unicode, and dict objects")
+            message = "write() only accepts bytes, unicode, and dict objects"
+            if isinstance(chunk, list):
+                message += ". Lists not accepted for security reasons; see http://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.write"
+            raise TypeError(message)
         if isinstance(chunk, dict):
             chunk = escape.json_encode(chunk)
             self.set_header("Content-Type", "application/json; charset=UTF-8")
@@ -785,6 +825,7 @@ class RequestHandler(object):
             current_user=self.current_user,
             locale=self.locale,
             _=self.locale.translate,
+            pgettext=self.locale.pgettext,
             static_url=self.static_url,
             xsrf_form_html=self.xsrf_form_html,
             reverse_url=self.reverse_url
@@ -829,7 +870,8 @@ class RequestHandler(object):
         for transform in self._transforms:
             self._status_code, self._headers, chunk = \
                 transform.transform_first_chunk(
-                    self._status_code, self._headers, chunk, include_footers)
+                    self._status_code, self._headers,
+                    chunk, include_footers)
         # Ignore the chunk and only write the headers for HEAD requests
         if self.request.method == "HEAD":
             chunk = None
@@ -860,9 +902,7 @@ class RequestHandler(object):
     def finish(self, chunk=None):
         """Finishes this response, ending the HTTP request."""
         if self._finished:
-            raise RuntimeError("finish() called twice. May be caused "
-                               "by using async operations without the "
-                               "@asynchronous decorator.")
+            raise RuntimeError("finish() called twice")
 
         if chunk is not None:
             self.write(chunk)
@@ -914,7 +954,15 @@ class RequestHandler(object):
         if self._headers_written:
             gen_log.error("Cannot send error response after headers written")
             if not self._finished:
-                self.finish()
+                # If we get an error between writing headers and finishing,
+                # we are unlikely to be able to finish due to a
+                # Content-Length mismatch. Try anyway to release the
+                # socket.
+                try:
+                    self.finish()
+                except Exception:
+                    gen_log.error("Failed to flush partial response",
+                                  exc_info=True)
             return
         self.clear()
 
@@ -1149,7 +1197,8 @@ class RequestHandler(object):
             return (version, token, timestamp)
         except Exception:
             # Catch exceptions and return nothing instead of failing.
-            gen_log.debug("Uncaught exception in _decode_xsrf_token", exc_info=True)
+            gen_log.debug("Uncaught exception in _decode_xsrf_token",
+                          exc_info=True)
             return None, None, None
 
     def check_xsrf_cookie(self):
@@ -1289,9 +1338,27 @@ class RequestHandler(object):
         before completing the request. The ``Etag`` header should be set
         (perhaps with `set_etag_header`) before calling this method.
         """
-        etag = self._headers.get("Etag")
-        inm = utf8(self.request.headers.get("If-None-Match", ""))
-        return bool(etag and inm and inm.find(etag) >= 0)
+        computed_etag = utf8(self._headers.get("Etag", ""))
+        # Find all weak and strong etag values from If-None-Match header
+        # because RFC 7232 allows multiple etag values in a single header.
+        etags = re.findall(
+            br'\*|(?:W/)?"[^"]*"',
+            utf8(self.request.headers.get("If-None-Match", ""))
+        )
+        if not computed_etag or not etags:
+            return False
+
+        match = False
+        if etags[0] == b'*':
+            match = True
+        else:
+            # Use a weak comparison when comparing entity-tags.
+            val = lambda x: x[2:] if x.startswith(b'W/') else x
+            for etag in etags:
+                if val(etag) == val(computed_etag):
+                    match = True
+                    break
+        return match
 
     def _stack_context_handle_exception(self, type, value, traceback):
         try:
@@ -1351,7 +1418,10 @@
                 if self._auto_finish and not self._finished:
                     self.finish()
         except Exception as e:
-            self._handle_request_exception(e)
+            try:
+                self._handle_request_exception(e)
+            except Exception:
+                app_log.error("Exception in exception handler", exc_info=True)
             if (self._prepared_future is not None and
                     not self._prepared_future.done()):
                 # In case we failed before setting _prepared_future, do it
@@ -1376,8 +1446,8 @@ class RequestHandler(object):
             self.application.log_request(self)
 
     def _request_summary(self):
-        return self.request.method + " " + self.request.uri + \
-            " (" + self.request.remote_ip + ")"
+        return "%s %s (%s)" % (self.request.method, self.request.uri,
+                               self.request.remote_ip)
 
     def _handle_request_exception(self, e):
         if isinstance(e, Finish):
@@ -1385,7 +1455,12 @@ class RequestHandler(object):
             if not self._finished:
                 self.finish()
             return
-        self.log_exception(*sys.exc_info())
+        try:
+            self.log_exception(*sys.exc_info())
+        except Exception:
+            # An error here should still get a best-effort send_error()
+            # to avoid leaking the connection.
+            app_log.error("Error in exception logger", exc_info=True)
         if self._finished:
             # Extra errors after the request has been finished should
             # be logged, but there is no reason to continue to try and
@@ -1448,10 +1523,11 @@ class RequestHandler(object):
 def asynchronous(method):
     """Wrap request handler methods with this if they are asynchronous.
 
-    This decorator is unnecessary if the method is also decorated with
-    ``@gen.coroutine`` (it is legal but unnecessary to use the two
-    decorators together, in which case ``@asynchronous`` must be
-    first).
+    This decorator is for callback-style asynchronous methods; for
+    coroutines, use the ``@gen.coroutine`` decorator without
+    ``@asynchronous``. (It is legal for legacy reasons to use the two
+    decorators together provided ``@asynchronous`` is first, but
+    ``@asynchronous`` will be ignored in this case)
 
     This decorator should only be applied to the :ref:`HTTP verb
     methods <verbs>`; its behavior is undefined for any other method.
@@ -1464,10 +1540,12 @@ def asynchronous(method):
     method returns. It is up to the request handler to call
     `self.finish() <RequestHandler.finish>` to finish the HTTP
     request. Without this decorator, the request is automatically
-    finished when the ``get()`` or ``post()`` method returns. Example::
+    finished when the ``get()`` or ``post()`` method returns. Example:
 
-        class MyRequestHandler(web.RequestHandler):
-            @web.asynchronous
+    .. testcode::
+
+        class MyRequestHandler(RequestHandler):
+            @asynchronous
             def get(self):
                 http = httpclient.AsyncHTTPClient()
                 http.fetch("http://friendfeed.com/", self._on_download)
@@ -1476,11 +1554,16 @@ def asynchronous(method):
             self.write("Downloaded!")
             self.finish()
 
+    .. testoutput::
+       :hide:
+
     .. versionadded:: 3.1
        The ability to use ``@gen.coroutine`` without ``@asynchronous``.
+
     """
     # Delay the IOLoop import because it's not available on app engine.
     from tornado.ioloop import IOLoop
+
     @functools.wraps(method)
     def wrapper(self, *args, **kwargs):
         self._auto_finish = False
@@ -1598,7 +1681,7 @@ class Application(httputil.HTTPServerConnectionDelegate):
         ])
         http_server = httpserver.HTTPServer(application)
         http_server.listen(8080)
-        ioloop.IOLoop.instance().start()
+        ioloop.IOLoop.current().start()
 
     The constructor for this class takes in a list of `URLSpec` objects
     or (regexp, request_class) tuples. When we receive requests, we
@@ -1696,7 +1779,7 @@ class Application(httputil.HTTPServerConnectionDelegate):
         `.TCPServer.bind`/`.TCPServer.start` methods directly.
 
         Note that after calling this method you still need to call
-        ``IOLoop.instance().start()`` to start the server.
+        ``IOLoop.current().start()`` to start the server.
         """
         # import is here rather than top level because HTTPServer
         # is not importable on appengine
@@ -1838,7 +1921,8 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate):
 
     def headers_received(self, start_line, headers):
         self.set_request(httputil.HTTPServerRequest(
-            connection=self.connection, start_line=start_line, headers=headers))
+            connection=self.connection, start_line=start_line,
+            headers=headers))
         if self.stream_request_body:
             self.request.body = Future()
         return self.execute()
@@ -1855,7 +1939,9 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate):
         handlers = app._get_host_handlers(self.request)
         if not handlers:
             self.handler_class = RedirectHandler
-            self.handler_kwargs = dict(url="%s://%s/" % (self.request.protocol, app.default_host))
+            self.handler_kwargs = dict(url="%s://%s/"
+                                       % (self.request.protocol,
+                                          app.default_host))
             return
         for spec in handlers:
             match = spec.regex.match(self.request.path)
@@ -1921,13 +2007,14 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate):
         if self.stream_request_body:
             self.handler._prepared_future = Future()
         # Note that if an exception escapes handler._execute it will be
-        # trapped in the Future it returns (which we are ignoring here).
+        # trapped in the Future it returns (which we are ignoring here,
+        # leaving it to be logged when the Future is GC'd).
         # However, that shouldn't happen because _execute has a blanket
         # except handler, and we cannot easily access the IOLoop here to
         # call add_future (because of the requirement to remain compatible
         # with WSGI)
-        f = self.handler._execute(transforms, *self.path_args, **self.path_kwargs)
-        f.add_done_callback(lambda f: f.exception())
+        f = self.handler._execute(transforms, *self.path_args,
+                                  **self.path_kwargs)
         # If we are streaming the request body, then execute() is finished
         # when the handler has prepared to receive the body. If not,
         # it doesn't matter when execute() finishes (so we return None)
@@ -1961,6 +2048,8 @@ class HTTPError(Exception):
         self.log_message = log_message
         self.args = args
         self.reason = kwargs.get('reason', None)
+        if log_message and not args:
+            self.log_message = log_message.replace('%', '%%')
 
     def __str__(self):
         message = "HTTP %d: %s" % (
@@ -2221,7 +2310,8 @@ class StaticFileHandler(RequestHandler):
 
         if content_type:
             self.set_header("Content-Type", content_type)
 
-        cache_time = self.get_cache_time(self.path, self.modified, content_type)
+        cache_time = self.get_cache_time(self.path, self.modified,
+                                         content_type)
         if cache_time > 0:
             self.set_header("Expires", datetime.datetime.utcnow() +
                             datetime.timedelta(seconds=cache_time))
@@ -2286,9 +2376,13 @@ class StaticFileHandler(RequestHandler):
 
         .. versionadded:: 3.1
         """
-        root = os.path.abspath(root)
-        # os.path.abspath strips a trailing /
-        # it needs to be temporarily added back for requests to root/
+        # os.path.abspath strips a trailing /.
+        # We must add it back to `root` so that we only match files
+        # in a directory named `root` instead of files starting with
+        # that prefix.
+        root = os.path.abspath(root) + os.path.sep
+        # The trailing slash also needs to be temporarily added back
+        # to the requested path so a request to root/ will match.
         if not (absolute_path + os.path.sep).startswith(root):
             raise HTTPError(403, "%s is not in root static directory",
                             self.path)
@@ -2390,7 +2484,8 @@ class StaticFileHandler(RequestHandler):
         .. versionadded:: 3.1
         """
         stat_result = self._stat()
-        modified = datetime.datetime.utcfromtimestamp(stat_result[stat.ST_MTIME])
+        modified = datetime.datetime.utcfromtimestamp(
+            stat_result[stat.ST_MTIME])
         return modified
 
     def get_content_type(self):
@@ -2651,7 +2746,8 @@ class UIModule(object):
         raise NotImplementedError()
 
     def embedded_javascript(self):
-        """Override to return a JavaScript string to be embedded in the page."""
+        """Override to return a JavaScript string
+        to be embedded in the page."""
         return None
 
     def javascript_files(self):
@@ -2663,7 +2759,8 @@ class UIModule(object):
         return None
 
     def embedded_css(self):
-        """Override to return a CSS string that will be embedded in the page."""
+        """Override to return a CSS string
+        that will be embedded in the page."""
         return None
 
     def css_files(self):
@@ -2885,11 +2982,13 @@ else:
         return result == 0
 
 
-def create_signed_value(secret, name, value, version=None, clock=None):
+def create_signed_value(secret, name, value, version=None, clock=None,
+                        key_version=None):
     if version is None:
         version = DEFAULT_SIGNED_VALUE_VERSION
     if clock is None:
         clock = time.time
+
     timestamp = utf8(str(int(clock())))
     value = base64.b64encode(utf8(value))
     if version == 1:
@@ -2906,7 +3005,7 @@ def create_signed_value(secret, name, value, version=None, clock=None):
     #
     # The fields are:
     # - format version (i.e. 2; no length prefix)
-    # - key version (currently 0; reserved for future key rotation features)
+    # - key version (integer, default is 0)
     # - timestamp (integer seconds since epoch)
     # - name (not encoded; assumed to be ~alphanumeric)
     # - value (base64-encoded)
@@ -2914,34 +3013,32 @@
         def format_field(s):
             return utf8("%d:" % len(s)) + utf8(s)
         to_sign = b"|".join([
-            b"2|1:0",
+            b"2",
+            format_field(str(key_version or 0)),
             format_field(timestamp),
             format_field(name),
             format_field(value),
             b''])
+
+        if isinstance(secret, dict):
+            assert key_version is not None, 'Key version must be set when sign key dict is used'
+            assert version >= 2, 'Version must be at least 2 for key version support'
+            secret = secret[key_version]
+
         signature = _create_signature_v2(secret, to_sign)
         return to_sign + signature
     else:
         raise ValueError("Unsupported version %d" % version)
 
 
-# A leading version number in decimal with no leading zeros, followed by a pipe.
+# A leading version number in decimal
+# with no leading zeros, followed by a pipe.
 _signed_value_version_re = re.compile(br"^([1-9][0-9]*)\|(.*)$")
 
 
-def decode_signed_value(secret, name, value, max_age_days=31, clock=None, min_version=None):
-    if clock is None:
-        clock = time.time
-    if min_version is None:
-        min_version = DEFAULT_SIGNED_VALUE_MIN_VERSION
-    if min_version > 2:
-        raise ValueError("Unsupported min_version %d" % min_version)
-    if not value:
-        return None
-
-    # Figure out what version this is. Version 1 did not include an
+def _get_version(value):
+    # Figures out what version value is. Version 1 did not include an
     # explicit version field and started with arbitrary base64 data,
     # which makes this tricky.
-    value = utf8(value)
     m = _signed_value_version_re.match(value)
     if m is None:
         version = 1
@@ -2958,13 +3055,31 @@ def decode_signed_value(secret, name, value, max_age_days=31, clock=None, min_ve
             version = 1
         except ValueError:
             version = 1
+    return version
+
+
+def decode_signed_value(secret, name, value, max_age_days=31,
+                        clock=None, min_version=None):
+    if clock is None:
+        clock = time.time
+    if min_version is None:
+        min_version = DEFAULT_SIGNED_VALUE_MIN_VERSION
+    if min_version > 2:
+        raise ValueError("Unsupported min_version %d" % min_version)
+    if not value:
+        return None
+
+    value = utf8(value)
+    version = _get_version(value)
 
     if version < min_version:
         return None
     if version == 1:
-        return _decode_signed_value_v1(secret, name, value, max_age_days, clock)
+        return _decode_signed_value_v1(secret, name, value,
+                                       max_age_days, clock)
     elif version == 2:
-        return _decode_signed_value_v2(secret, name, value, max_age_days, clock)
+        return _decode_signed_value_v2(secret, name, value,
+                                       max_age_days, clock)
     else:
         return None
 
@@ -2987,7 +3102,8 @@ def _decode_signed_value_v1(secret, name, value, max_age_days, clock):
         # digits from the payload to the timestamp without altering the
         # signature. For backwards compatibility, sanity-check timestamp
         # here instead of modifying _cookie_signature.
- gen_log.warning("Cookie timestamp in future; possible tampering %r", value) + gen_log.warning("Cookie timestamp in future; possible tampering %r", + value) return None if parts[1].startswith(b"0"): gen_log.warning("Tampered cookie %r", value) @@ -2998,7 +3114,7 @@ def _decode_signed_value_v1(secret, name, value, max_age_days, clock): return None -def _decode_signed_value_v2(secret, name, value, max_age_days, clock): +def _decode_fields_v2(value): def _consume_field(s): length, _, rest = s.partition(b':') n = int(length) @@ -3009,16 +3125,28 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock): raise ValueError("malformed v2 signed value field") rest = rest[n + 1:] return field_value, rest + rest = value[2:] # remove version number + key_version, rest = _consume_field(rest) + timestamp, rest = _consume_field(rest) + name_field, rest = _consume_field(rest) + value_field, passed_sig = _consume_field(rest) + return int(key_version), timestamp, name_field, value_field, passed_sig + + +def _decode_signed_value_v2(secret, name, value, max_age_days, clock): try: - key_version, rest = _consume_field(rest) - timestamp, rest = _consume_field(rest) - name_field, rest = _consume_field(rest) - value_field, rest = _consume_field(rest) + key_version, timestamp, name_field, value_field, passed_sig = _decode_fields_v2(value) except ValueError: return None - passed_sig = rest signed_string = value[:-len(passed_sig)] + + if isinstance(secret, dict): + try: + secret = secret[key_version] + except KeyError: + return None + expected_sig = _create_signature_v2(secret, signed_string) if not _time_independent_equals(passed_sig, expected_sig): return None @@ -3034,6 +3162,19 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock): return None +def get_signature_key_version(value): + value = utf8(value) + version = _get_version(value) + if version < 2: + return None + try: + key_version, _, _, _, _ = _decode_fields_v2(value) + except ValueError: + return None + + return key_version + + def _create_signature_v1(secret, *parts): hash = hmac.new(utf8(secret), digestmod=hashlib.sha1) for part in parts: diff --git a/tornado/websocket.py b/tornado/websocket.py index c009225ce51decfd59ea964855648cbe10babe92..2f57b99093ac76a9a545c676dd8550bbf66ae626 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -16,7 +16,8 @@ the protocol (known as "draft 76") and are not compatible with this module. Removed support for the draft 76 protocol version. """ -from __future__ import absolute_import, division, print_function, with_statement +from __future__ import (absolute_import, division, + print_function, with_statement) # Author: Jacob Kristhammar, 2010 import base64 @@ -39,9 +40,9 @@ from tornado.tcpclient import TCPClient from tornado.util import _websocket_mask try: - from urllib.parse import urlparse # py2 + from urllib.parse import urlparse # py2 except ImportError: - from urlparse import urlparse # py3 + from urlparse import urlparse # py3 try: xrange # py2 @@ -74,17 +75,22 @@ class WebSocketHandler(tornado.web.RequestHandler): http://tools.ietf.org/html/rfc6455. Here is an example WebSocket handler that echos back all received messages - back to the client:: + back to the client: - class EchoWebSocket(websocket.WebSocketHandler): + .. 
testcode:: + + class EchoWebSocket(tornado.websocket.WebSocketHandler): def open(self): - print "WebSocket opened" + print("WebSocket opened") def on_message(self, message): self.write_message(u"You said: " + message) def on_close(self): - print "WebSocket closed" + print("WebSocket closed") + + .. testoutput:: + :hide: WebSockets are not standard HTTP connections. The "handshake" is HTTP, but after the handshake, the protocol is @@ -139,16 +145,22 @@ class WebSocketHandler(tornado.web.RequestHandler): # Upgrade header should be present and should be equal to WebSocket if self.request.headers.get("Upgrade", "").lower() != 'websocket': self.set_status(400) - self.finish("Can \"Upgrade\" only to \"WebSocket\".") + log_msg = "Can \"Upgrade\" only to \"WebSocket\"." + self.finish(log_msg) + gen_log.debug(log_msg) return - # Connection header should be upgrade. Some proxy servers/load balancers + # Connection header should be upgrade. + # Some proxy servers/load balancers # might mess with it. headers = self.request.headers - connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(",")) + connection = map(lambda s: s.strip().lower(), + headers.get("Connection", "").split(",")) if 'upgrade' not in connection: self.set_status(400) - self.finish("\"Connection\" must be \"Upgrade\".") + log_msg = "\"Connection\" must be \"Upgrade\"." + self.finish(log_msg) + gen_log.debug(log_msg) return # Handle WebSocket Origin naming convention differences @@ -160,13 +172,14 @@ class WebSocketHandler(tornado.web.RequestHandler): else: origin = self.request.headers.get("Sec-Websocket-Origin", None) - # If there was an origin header, check to make sure it matches # according to check_origin. When the origin is None, we assume it # did not come from a browser and that it can be passed on. if origin is not None and not self.check_origin(origin): self.set_status(403) - self.finish("Cross origin websockets not allowed") + log_msg = "Cross origin websockets not allowed" + self.finish(log_msg) + gen_log.debug(log_msg) return self.stream = self.request.connection.detach() @@ -350,7 +363,7 @@ class WebSocketHandler(tornado.web.RequestHandler): self.ws_connection.on_connection_close() self.ws_connection = None if not self._on_close_called: - self._on_close_called + self._on_close_called = True self.on_close() def send_error(self, *args, **kwargs): @@ -507,7 +520,8 @@ class WebSocketProtocol13(WebSocketProtocol): self._handle_websocket_headers() self._accept_connection() except ValueError: - gen_log.debug("Malformed WebSocket request received", exc_info=True) + gen_log.debug("Malformed WebSocket request received", + exc_info=True) self._abort() return @@ -543,18 +557,19 @@ class WebSocketProtocol13(WebSocketProtocol): selected = self.handler.select_subprotocol(subprotocols) if selected: assert selected in subprotocols - subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected + subprotocol_header = ("Sec-WebSocket-Protocol: %s\r\n" + % selected) extension_header = '' extensions = self._parse_extensions_header(self.request.headers) for ext in extensions: if (ext[0] == 'permessage-deflate' and - self._compression_options is not None): + self._compression_options is not None): # TODO: negotiate parameters if compression_options # specifies limits. self._create_compressors('server', ext[1]) if ('client_max_window_bits' in ext[1] and - ext[1]['client_max_window_bits'] is None): + ext[1]['client_max_window_bits'] is None): # Don't echo an offered client_max_window_bits # parameter with no value. 
                     del ext[1]['client_max_window_bits']
@@ -599,7 +614,7 @@ class WebSocketProtocol13(WebSocketProtocol):
         extensions = self._parse_extensions_header(headers)
         for ext in extensions:
             if (ext[0] == 'permessage-deflate' and
-                self._compression_options is not None):
+                    self._compression_options is not None):
                 self._create_compressors('client', ext[1])
             else:
                 raise ValueError("unsupported extension %r", ext)
@@ -711,7 +726,8 @@ class WebSocketProtocol13(WebSocketProtocol):
                 if self._masked_frame:
                     self.stream.read_bytes(4, self._on_masking_key)
                 else:
-                    self.stream.read_bytes(self._frame_length, self._on_frame_data)
+                    self.stream.read_bytes(self._frame_length,
+                                           self._on_frame_data)
             elif payloadlen == 126:
                 self.stream.read_bytes(2, self._on_frame_length_16)
             elif payloadlen == 127:
@@ -745,7 +761,8 @@ class WebSocketProtocol13(WebSocketProtocol):
         self._wire_bytes_in += len(data)
         self._frame_mask = data
         try:
-            self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
+            self.stream.read_bytes(self._frame_length,
+                                   self._on_masked_frame_data)
         except StreamClosedError:
             self._abort()
 
@@ -818,7 +835,8 @@ class WebSocketProtocol13(WebSocketProtocol):
                 self.handler.close_code = struct.unpack('>H', data[:2])[0]
             if len(data) > 2:
                 self.handler.close_reason = to_unicode(data[2:])
-            self.close()
+            # Echo the received close code, if any (RFC 6455 section 5.5.1).
+            self.close(self.handler.close_code)
         elif opcode == 0x9:
             # Ping
             self._write_frame(True, 0xA, data)
@@ -869,6 +887,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         self.read_queue = collections.deque()
         self.key = base64.b64encode(os.urandom(16))
         self._on_message_callback = on_message_callback
+        self.close_code = self.close_reason = None
 
         scheme, sep, rest = request.url.partition(':')
         scheme = {'ws': 'http', 'wss': 'https'}[scheme]
@@ -891,7 +910,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         self.tcp_client = TCPClient(io_loop=io_loop)
         super(WebSocketClientConnection, self).__init__(
             io_loop, None, request, lambda: None, self._on_http_response,
-            104857600, self.tcp_client, 65536)
+            104857600, self.tcp_client, 65536, 104857600)
 
     def close(self, code=None, reason=None):
         """Closes the websocket connection.
diff --git a/tornado/wsgi.py b/tornado/wsgi.py
index e7e07fbc9cf3e960eebe66b8263f198476a1af35..59e6c559f1289695ffc99627ee65ab3398562e5c 100644
--- a/tornado/wsgi.py
+++ b/tornado/wsgi.py
@@ -253,7 +253,7 @@ class WSGIContainer(object):
         container = tornado.wsgi.WSGIContainer(simple_app)
         http_server = tornado.httpserver.HTTPServer(container)
         http_server.listen(8888)
-        tornado.ioloop.IOLoop.instance().start()
+        tornado.ioloop.IOLoop.current().start()
 
     This class is intended to let other frameworks (Django, web.py, etc)
     run on the Tornado HTTP server and I/O loop.
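For reference, the container example in the docstring above, made
self-contained; ``simple_app`` is the usual minimal WSGI callable, assumed
here for illustration rather than taken from the patch::

    import tornado.httpserver
    import tornado.ioloop
    import tornado.wsgi

    def simple_app(environ, start_response):
        status = "200 OK"
        response_headers = [("Content-Type", "text/plain")]
        start_response(status, response_headers)
        return [b"Hello, world!\n"]

    if __name__ == "__main__":
        container = tornado.wsgi.WSGIContainer(simple_app)
        http_server = tornado.httpserver.HTTPServer(container)
        http_server.listen(8888)
        tornado.ioloop.IOLoop.current().start()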
@@ -284,7 +284,8 @@ class WSGIContainer(object):
         if not data:
             raise Exception("WSGI app did not call start_response")
 
-        status_code = int(data["status"].split()[0])
+        status_code, reason = data["status"].split(' ', 1)
+        status_code = int(status_code)
         headers = data["headers"]
         header_set = set(k.lower() for (k, v) in headers)
         body = escape.utf8(body)
@@ -296,13 +297,12 @@ class WSGIContainer(object):
         if "server" not in header_set:
             headers.append(("Server", "TornadoServer/%s" % tornado.version))
 
-        parts = [escape.utf8("HTTP/1.1 " + data["status"] + "\r\n")]
+        start_line = httputil.ResponseStartLine("HTTP/1.1", status_code, reason)
+        header_obj = httputil.HTTPHeaders()
         for key, value in headers:
-            parts.append(escape.utf8(key) + b": " + escape.utf8(value) + b"\r\n")
-        parts.append(b"\r\n")
-        parts.append(body)
-        request.write(b"".join(parts))
-        request.finish()
+            header_obj.add(key, value)
+        request.connection.write_headers(start_line, header_obj, chunk=body)
+        request.connection.finish()
         self._log(status_code, request)
 
     @staticmethod
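The status-line handling changed in the hunk above splits the WSGI status
string once, on the first space, so multi-word reason phrases survive intact
instead of being pasted into a hand-built response line. A small worked
example of that split (values are illustrative)::

    status = "404 Not Found"            # as passed to start_response()
    status_code, reason = status.split(' ', 1)
    status_code = int(status_code)
    assert (status_code, reason) == (404, "Not Found")

Note this relies on the WSGI app supplying a reason phrase after the code,
which PEP 3333 requires of the ``status`` string.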