diff --git a/.travis.yml b/.travis.yml index 32a67c0729f66909541c450affb1fc2916b521d7..2799cf2f9fe24ec3a2551150c6035202a36e3de4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,6 +10,10 @@ branches: install: - pip install cheetah pyopenssl==0.13.1 +cache: + directories: + - $HOME/.cache/pip + before_script: - chmod +x ./tests/all_tests.py diff --git a/gui/slick/interfaces/default/editShow.tmpl b/gui/slick/interfaces/default/editShow.tmpl index a73698500eb8399009e99eba5cd23f6908e3ee90..3f959330ee025f79605b30f7a6d9640a31a7a1dd 100644 --- a/gui/slick/interfaces/default/editShow.tmpl +++ b/gui/slick/interfaces/default/editShow.tmpl @@ -76,7 +76,7 @@ This will <b>affect the episode show search</b> on nzb and torrent provider.<br <b>Info Language:</b><br /> (this will only affect the language of the retrieved metadata file contents and episode filenames)<br /> -<select name="indexerLang" id="indexerLangSelect" class="form-control form-control-inline input-sm bfh-languages" data-language="#echo $sickbeard.INDEXER_DEFAULT_LANGUAGE#" data-available="#echo ','.join($sickbeard.indexerApi().config['valid_languages'])#"></select><br /> +<select name="indexerLang" id="indexerLangSelect" class="form-control form-control-inline input-sm bfh-languages" data-language="#echo $show.lang#" data-available="#echo ','.join($sickbeard.indexerApi().config['valid_languages'])#"></select><br /> <br /> <b>Flatten files (no folders): </b> <input type="checkbox" name="flatten_folders" #if $show.flatten_folders == 1 and not $sickbeard.NAMING_FORCE_FOLDERS then "checked=\"checked\"" else ""# #if $sickbeard.NAMING_FORCE_FOLDERS then "disabled=\"disabled\"" else ""#/><br /> diff --git a/gui/slick/interfaces/default/history.tmpl b/gui/slick/interfaces/default/history.tmpl index 1d31ec0b5129afed2506d24393458d2fc106e6ed..2123adea52b91ad1c39ce4feeef528e71bd9ab78 100644 --- a/gui/slick/interfaces/default/history.tmpl +++ b/gui/slick/interfaces/default/history.tmpl @@ -213,16 +213,10 @@ #for $action in 
sorted($hItem["actions"]): #set $curStatus, $curQuality = $Quality.splitCompositeStatus(int($action["action"])) #if $curStatus == DOWNLOADED: - #set $match = $re.search("\-(\w+)\.\w{3}\Z", $os.path.basename($action["resource"])) - #if $match - #if $match.group(1).upper() in ("X264", "720P"): - #set $match = $re.search("(\w+)\-.*\-"+$match.group(1)+"\.\w{3}\Z", $os.path.basename($hItem["resource"]), re.IGNORECASE) - #if $match - <span style="cursor: help;" title="$os.path.basename($action["resource"])"><i>$match.group(1).upper()</i></span> - #end if - #else: - <span style="cursor: help;" title="$os.path.basename($action["resource"])"><i>$match.group(1).upper()</i></span> - #end if + #if $action["provider"] != "-1": + <span style="cursor: help;" title="$os.path.basename($action["resource"])"><i>$action["provider"]</i></span> + #else: + <span style="cursor: help;" title="$os.path.basename($action["resource"])"></span> #end if #end if #end for diff --git a/lib/trakt/trakt.py b/lib/trakt/trakt.py index 98c5ed157ffb42594d9a19a18b996d8ee71f2bac..8ce2e072842aa0aec8fbeec3aa7cfa423f5ba530 100644 --- a/lib/trakt/trakt.py +++ b/lib/trakt/trakt.py @@ -74,9 +74,12 @@ class TraktAPI(): if None == headers: headers = self.headers - if not None == sickbeard.TRAKT_ACCESS_TOKEN: - headers['Authorization'] = 'Bearer ' + sickbeard.TRAKT_ACCESS_TOKEN - + if None == sickbeard.TRAKT_ACCESS_TOKEN: + logger.log(u'You must get a Trakt TOKEN. 
Check your Trakt settings', logger.WARNING) + raise traktAuthException(e) + + headers['Authorization'] = 'Bearer ' + sickbeard.TRAKT_ACCESS_TOKEN + try: resp = requests.request(method, url, headers=headers, timeout=self.timeout, data=json.dumps(data) if data else [], verify=self.verify) diff --git a/sickbeard/common.py b/sickbeard/common.py index 2786762a591646858687ad7f73a74d1aae404182..bccb2acf60c9eaa5b84a3453cda6b558acf08cb9 100644 --- a/sickbeard/common.py +++ b/sickbeard/common.py @@ -23,10 +23,7 @@ import re import uuid INSTANCE_ID = str(uuid.uuid1()) -#Use Sick Beard USER_AGENT until they stop throttling us, -#newznab searching has long been fixed, but we now limit it to 400 results just as they do. -#USER_AGENT = ('SickRage/(' + platform.system() + '; ' + platform.release() + '; ' + INSTANCE_ID + ')') -USER_AGENT = 'Sick Beard/alpha2-master' + ' (' + platform.system() + ' ' + platform.release() + ')' +USER_AGENT = ('SickRage/(' + platform.system() + '; ' + platform.release() + '; ' + INSTANCE_ID + ')') mediaExtensions = ['avi', 'mkv', 'mpg', 'mpeg', 'wmv', 'ogm', 'mp4', 'iso', 'img', 'divx', @@ -36,9 +33,9 @@ mediaExtensions = ['avi', 'mkv', 'mpg', 'mpeg', 'wmv', subtitleExtensions = ['srt', 'sub', 'ass', 'idx', 'ssa'] -cpu_presets = {'HIGH': 0.1, - 'NORMAL': 0.05, - 'LOW': 0.01 +cpu_presets = {'HIGH': 5, + 'NORMAL': 2, + 'LOW': 1 } ### Other constants diff --git a/sickbeard/helpers.py b/sickbeard/helpers.py index 1d76f79847c08e07e060b6d60251b401773a62ee..de0087f0d6d5ad71365e27f1a16090620be0390c 100644 --- a/sickbeard/helpers.py +++ b/sickbeard/helpers.py @@ -137,10 +137,7 @@ def remove_non_release_groups(name): elif remove_type == 'searchre': _name = re.sub(r'(?i)' + remove_string, '', _name) - #if _name != name: - # logger.log(u'Change title from {old_name} to {new_name}'.format(old_name=name, new_name=_name), logger.DEBUG) - - return _name + return _name.strip('.- ') def replaceExtension(filename, newExt): @@ -163,6 +160,9 @@ def 
replaceExtension(filename, newExt): return sepFile[0] + "." + newExt +def notTorNZBFile(filename): + return not (filename.endswith(".torrent") or filename.endswith(".nzb")) + def isSyncFile(filename): extension = filename.rpartition(".")[2].lower() #if extension == '!sync' or extension == 'lftp-pget-status' or extension == 'part' or extension == 'bts': @@ -1374,13 +1374,17 @@ def download_file(url, filename, session=None): resp.status_code) + ': ' + codeDescription(resp.status_code), logger.DEBUG) return False - with open(filename, 'wb') as fp: - for chunk in resp.iter_content(chunk_size=1024): - if chunk: - fp.write(chunk) - fp.flush() + try: + with open(filename, 'wb') as fp: + for chunk in resp.iter_content(chunk_size=1024): + if chunk: + fp.write(chunk) + fp.flush() + + chmodAsParent(filename) + except: + logger.log(u"Problem setting permissions or writing file to: %s" % filename, logger.WARNING) - chmodAsParent(filename) except requests.exceptions.HTTPError, e: _remove_file_failed(filename) logger.log(u"HTTP error " + str(e.errno) + " while loading URL " + url, logger.WARNING) diff --git a/sickbeard/name_cache.py b/sickbeard/name_cache.py index 17c49eb97536f9fad9c7ec7308ce01073a552337..4ebfcab907366823f5b7662ac7ec55e251bc7477 100644 --- a/sickbeard/name_cache.py +++ b/sickbeard/name_cache.py @@ -55,7 +55,7 @@ def retrieveNameFromCache(name): if name in nameCache: return int(nameCache[name]) -def clearCache(): +def clearCache(indexerid=0): """ Deletes all "unknown" entries from the cache (names with indexer_id of 0). """ @@ -66,9 +66,9 @@ def clearCache(): nameCache = {} cacheDB = db.DBConnection('cache.db') - cacheDB.action("DELETE FROM scene_names WHERE indexer_id = ?", [0]) + cacheDB.action("DELETE FROM scene_names WHERE indexer_id = ? 
OR indexer_id = ?", (indexerid, 0)) - toRemove = [key for key, value in nameCache.iteritems() if value == 0] + toRemove = [key for key, value in nameCache.iteritems() if value == 0 or value == indexerid] for key in toRemove: del nameCache[key] @@ -83,39 +83,16 @@ def saveNameCacheToDb(): def buildNameCache(show=None): global nameCache - with nameCacheLock: - # clear internal name cache - clearCache() - - # update scene exception names - sickbeard.scene_exceptions.retrieve_exceptions() - - if not show: - logger.log(u"Building internal name cache for all shows", logger.INFO) - - cacheDB = db.DBConnection('cache.db') - cache_results = cacheDB.select("SELECT * FROM scene_names") - for cache_result in cache_results: - name = sickbeard.helpers.full_sanitizeSceneName(cache_result["name"]) - if name in nameCache: - continue - - indexer_id = int(cache_result["indexer_id"]) - nameCache[name] = indexer_id - - for show in sickbeard.showList: - for curSeason in [-1] + sickbeard.scene_exceptions.get_scene_seasons(show.indexerid): - for name in list(set( - sickbeard.scene_exceptions.get_scene_exceptions(show.indexerid, season=curSeason) + [ - show.name])): - name = sickbeard.helpers.full_sanitizeSceneName(name) - if name in nameCache: - continue - - nameCache[name] = int(show.indexerid) - else: - logger.log(u"Building internal name cache for " + show.name, logger.INFO) + sickbeard.scene_exceptions.retrieve_exceptions() + if not show: + logger.log(u"Building internal name cache for all shows", logger.INFO) + for show in sickbeard.showList: + buildNameCache(show) + else: + with nameCacheLock: + logger.log(u"Building internal name cache for " + show.name, logger.INFO) + clearCache(show.indexerid) for curSeason in [-1] + sickbeard.scene_exceptions.get_scene_seasons(show.indexerid): for name in list(set(sickbeard.scene_exceptions.get_scene_exceptions(show.indexerid, season=curSeason) + [ show.name])): @@ -124,5 +101,4 @@ def buildNameCache(show=None): continue nameCache[name] = 
int(show.indexerid) - - logger.log(u"Internal name cache set to: " + str(nameCache), logger.DEBUG) \ No newline at end of file + logger.log(u"Internal name cache for " + show.name + " set to: [ " + u', '.join([key for key, value in nameCache.iteritems() if value == show.indexerid]) +" ]" , logger.DEBUG) diff --git a/sickbeard/postProcessor.py b/sickbeard/postProcessor.py index fc5df375813ebd4f1e9efcb50c5043d3d61c1ef0..2db0681627ac64e88398c9b2a2a2e90e8cd620c9 100644 --- a/sickbeard/postProcessor.py +++ b/sickbeard/postProcessor.py @@ -184,7 +184,7 @@ class PostProcessor(object): base_name = re.sub(r'[\[\]\*\?]', r'[\g<0>]', base_name) if subfolders: # subfolders are only checked in show folder, so names will always be exactly alike - filelist = ek.ek(recursive_glob, ek.ek(os.path.dirname, file_path), base_name + '*') # just create the list of all files starting with the basename + filelist = ek.ek(recursive_glob, ek.ek(os.path.dirname, file_path), base_name + '*') # just create the list of all files starting with the basename else: # this is called when PP, so we need to do the filename check case-insensitive filelist = [] checklist = ek.ek(glob.glob, ek.ek(os.path.join, ek.ek(os.path.dirname, file_path), '*')) # get a list of all the files in the folder diff --git a/sickbeard/processTV.py b/sickbeard/processTV.py index 68c1c99779b667fec3f8234231d398187e61a177..a1d49d71bca084db46160cfc7f3608486696aca7 100644 --- a/sickbeard/processTV.py +++ b/sickbeard/processTV.py @@ -155,12 +155,13 @@ def processDir(dirName, nzbName=None, process_method=None, force=False, is_prior path, dirs, files = get_path_dir_files(dirName, nzbName, type) + files = filter(helpers.notTorNZBFile, files) SyncFiles = filter(helpers.isSyncFile, files) # Don't post process if files are still being synced and option is activated if SyncFiles and sickbeard.POSTPONE_IF_SYNC_FILES: postpone = True - + nzbNameOriginal = nzbName if not postpone: diff --git a/sickbeard/properFinder.py 
b/sickbeard/properFinder.py index 4a0a2d45bf33d960d201e8bf087824787bd3fed4..29d3390931754cf21757fee85e76288aba836488 100644 --- a/sickbeard/properFinder.py +++ b/sickbeard/properFinder.py @@ -243,6 +243,7 @@ class ProperFinder(): # snatch it search.snatchEpisode(result, SNATCHED_PROPER) + time.sleep(cpu_presets[sickbeard.CPU_PRESET]) def _genericName(self, name): return name.replace(".", " ").replace("-", " ").replace("_", " ").lower() diff --git a/sickbeard/providers/generic.py b/sickbeard/providers/generic.py index 7b6e8ca3ab3b6b148daa0a285d59aa3ee3369f51..c836c49d0b4d9b5565396e5ac396681454c0ba35 100644 --- a/sickbeard/providers/generic.py +++ b/sickbeard/providers/generic.py @@ -174,20 +174,21 @@ class GenericProvider: return for url in urls: + logger.log(u"Downloading a result from " + self.name + " at " + url) if helpers.download_file(url, filename, session=self.session): - logger.log(u"Downloading a result from " + self.name + " at " + url) - - if self.providerType == GenericProvider.TORRENT: - logger.log(u"Saved magnet link to " + filename, logger.INFO) - else: - logger.log(u"Saved result to " + filename, logger.INFO) - if self._verify_download(filename): + if self.providerType == GenericProvider.TORRENT: + logger.log(u"Saved magnet link to " + filename, logger.INFO) + else: + logger.log(u"Saved result to " + filename, logger.INFO) return True else: + logger.log(u"Could not download %s" % url, logger.WARNING) helpers._remove_file_failed(filename) - logger.log(u"Failed to download result", logger.WARNING) + if len(urls): + logger.log(u"Failed to download any results", logger.WARNING) + return False def _verify_download(self, file_name=None): diff --git a/sickbeard/providers/newznab.py b/sickbeard/providers/newznab.py index 96696d1bc1e436543b9ec7b6a2673befef9be7af..41ce16c4aa3f78f1af7c9dfd0d21224eaa19fd6d 100644 --- a/sickbeard/providers/newznab.py +++ b/sickbeard/providers/newznab.py @@ -20,20 +20,22 @@ import urllib import time import datetime import os 
+import re import sickbeard import generic - +from sickbeard.common import Quality from sickbeard import classes from sickbeard import helpers from sickbeard import scene_exceptions from sickbeard import encodingKludge as ek from sickbeard import logger from sickbeard import tvcache +from sickbeard import db from sickbeard.exceptions import AuthException class NewznabProvider(generic.NZBProvider): - def __init__(self, name, url, key='', catIDs='5030,5040', search_mode='eponly', search_fallback=False, + def __init__(self, name, url, key='0', catIDs='5030,5040', search_mode='eponly', search_fallback=False, enable_daily=False, enable_backlog=False): generic.NZBProvider.__init__(self, name) @@ -66,6 +68,7 @@ class NewznabProvider(generic.NZBProvider): self.supportsBacklog = True self.default = False + self.last_search = datetime.datetime.now() def configStr(self): return self.name + '|' + self.url + '|' + self.key + '|' + self.catIDs + '|' + str( @@ -129,6 +132,8 @@ class NewznabProvider(generic.NZBProvider): to_return = [] cur_params = {} + cur_params['maxage'] = (datetime.datetime.now() - datetime.datetime.combine(ep_obj.airdate, datetime.datetime.min.time())).days + 1 + # season if ep_obj.show.air_by_date or ep_obj.show.sports: date_str = str(ep_obj.airdate).split('-')[0] @@ -142,9 +147,9 @@ class NewznabProvider(generic.NZBProvider): # search rid = helpers.mapIndexersToShow(ep_obj.show)[2] if rid: - cur_return = cur_params.copy() - cur_return['rid'] = rid - to_return.append(cur_return) + cur_params['rid'] = rid + elif 'rid' in params: + cur_params.pop('rid') # add new query strings for exceptions name_exceptions = list( @@ -152,6 +157,8 @@ class NewznabProvider(generic.NZBProvider): for cur_exception in name_exceptions: if 'q' in cur_params: cur_params['q'] = helpers.sanitizeSceneName(cur_exception) + '.' 
+ cur_params['q'] + else: + cur_params['q'] = helpers.sanitizeSceneName(cur_exception) to_return.append(cur_params) return to_return @@ -163,6 +170,8 @@ class NewznabProvider(generic.NZBProvider): if not ep_obj: return [params] + params['maxage'] = (datetime.datetime.now() - datetime.datetime.combine(ep_obj.airdate, datetime.datetime.min.time())).days + 1 + if ep_obj.show.air_by_date or ep_obj.show.sports: date_str = str(ep_obj.airdate) params['season'] = date_str.partition('-')[0] @@ -176,30 +185,27 @@ class NewznabProvider(generic.NZBProvider): # search rid = helpers.mapIndexersToShow(ep_obj.show)[2] if rid: - cur_return = params.copy() - cur_return['rid'] = rid - to_return.append(cur_return) + params['rid'] = rid + elif 'rid' in params: + params.pop('rid') # add new query strings for exceptions name_exceptions = list( set(scene_exceptions.get_scene_exceptions(ep_obj.show.indexerid) + [ep_obj.show.name])) for cur_exception in name_exceptions: params['q'] = helpers.sanitizeSceneName(cur_exception) + if add_string: + params['q'] += ' ' + add_string + to_return.append(params) - + if ep_obj.show.anime: - # Experimental, add a searchstring without search explicitly for the episode! - # Remove the ?ep=e46 paramater and use add the episode number to the query paramater. - # Can be usefull for newznab indexers that do not have the episodes 100% parsed. 
- # Start with only applying the searchstring to anime shows - params['q'] = helpers.sanitizeSceneName(cur_exception) paramsNoEp = params.copy() - - paramsNoEp['q'] = paramsNoEp['q'] + " " + str(paramsNoEp['ep']) + paramsNoEp['q'] = paramsNoEp['q'] + " " + paramsNoEp['ep'] if "ep" in paramsNoEp: paramsNoEp.pop("ep") to_return.append(paramsNoEp) - + return to_return def _doGeneralSearch(self, search_string): @@ -222,6 +228,8 @@ class NewznabProvider(generic.NZBProvider): except:return self._checkAuth() try: + bozo = int(data['bozo']) + bozo_exception = data['bozo_exception'] err_code = int(data['feed']['error']['code']) err_desc = data['feed']['error']['description'] if not err_code or err_desc: @@ -236,6 +244,8 @@ class NewznabProvider(generic.NZBProvider): elif err_code == 102: raise AuthException( "Your account isn't allowed to use the API on " + self.name + ", contact the administrator") + elif bozo == 1: + raise Exception(bozo_exception) else: logger.log(u"Unknown error given from " + self.name + ": " + err_desc, logger.ERROR) @@ -257,9 +267,7 @@ class NewznabProvider(generic.NZBProvider): else: params['cat'] = self.catIDs - # if max_age is set, use it, don't allow it to be missing - if age or not params['maxage']: - params['maxage'] = age + params['maxage'] = (4, age)[age] if search_params: params.update(search_params) @@ -270,12 +278,18 @@ class NewznabProvider(generic.NZBProvider): results = [] offset = total = 0 - # Limit to 400 results, like Sick Beard does, to prevent throttling - while (total >= offset) and (offset <= 400): + while (total >= offset): search_url = self.url + 'api?' 
+ urllib.urlencode(params) + + while((datetime.datetime.now() - self.last_search).seconds < 5): + time.sleep(1) + logger.log(u"Search url: " + search_url, logger.DEBUG) data = self.cache.getRSSFeed(search_url) + + self.last_search = datetime.datetime.now() + if not self._checkAuthFromData(data): break @@ -307,7 +321,7 @@ class NewznabProvider(generic.NZBProvider): break params['offset'] += params['limit'] - if (total > int(params['offset'])) and (int(params['offset']) <= 400): + if (total > int(params['offset'])) and (offset < 500): offset = int(params['offset']) # if there are more items available then the amount given in one call, grab some more logger.log(u'%d' % (total - offset) + ' more items to be fetched from provider.' + @@ -316,68 +330,33 @@ class NewznabProvider(generic.NZBProvider): logger.log(u'No more searches needed.', logger.DEBUG) break - time.sleep(0.2) - return results - def findPropers(self, search_date=None): - - search_terms = ['.proper.', '.repack.'] - - cache_results = self.cache.listPropers(search_date) - results = [classes.Proper(x['name'], x['url'], datetime.datetime.fromtimestamp(x['time']), self.show) for x in - cache_results] - - index = 0 - alt_search = ('nzbs_org' == self.getID()) - term_items_found = False - do_search_alt = False - - while index < len(search_terms): - search_params = {'q': search_terms[index]} - if alt_search: - - if do_search_alt: - index += 1 - - if term_items_found: - do_search_alt = True - term_items_found = False - else: - if do_search_alt: - search_params['t'] = "search" - - do_search_alt = (True, False)[do_search_alt] - - else: - index += 1 - - for item in self._doSearch(search_params, age=4): - - (title, url) = self._get_title_and_url(item) + def findPropers(self, search_date=datetime.datetime.today()): + results = [] - try: - result_date = datetime.datetime(*item['published_parsed'][0:6]) - except (AttributeError, KeyError): - try: - result_date = datetime.datetime(*item['updated_parsed'][0:6]) - except 
(AttributeError, KeyError): - try: - result_date = datetime.datetime(*item['created_parsed'][0:6]) - except (AttributeError, KeyError): - try: - result_date = datetime.datetime(*item['date'][0:6]) - except (AttributeError, KeyError): - logger.log(u"Unable to figure out the date for entry " + title + ", skipping it") - continue - - if not search_date or result_date > search_date: - search_result = classes.Proper(title, url, result_date, self.show) - results.append(search_result) - term_items_found = True - do_search_alt = False - - time.sleep(0.2) + myDB = db.DBConnection() + sqlResults = myDB.select( + 'SELECT s.show_name, e.showid, e.season, e.episode, e.status, e.airdate FROM tv_episodes AS e' + + ' INNER JOIN tv_shows AS s ON (e.showid = s.indexer_id)' + + ' WHERE e.airdate >= ' + str(search_date.toordinal()) + + ' AND (e.status IN (' + ','.join([str(x) for x in Quality.DOWNLOADED]) + ')' + + ' OR (e.status IN (' + ','.join([str(x) for x in Quality.SNATCHED]) + ')))' + ) + + if not sqlResults: + return [] + + for sqlshow in sqlResults: + self.show = helpers.findCertainShow(sickbeard.showList, int(sqlshow["showid"])) + if self.show: + curEp = self.show.getEpisode(int(sqlshow["season"]), int(sqlshow["episode"])) + searchStrings = self._get_episode_search_strings(curEp, add_string='PROPER|REPACK') + for searchString in searchStrings: + for item in self._doSearch(searchString): + title, url = self._get_title_and_url(item) + if(re.match(r'.*(REPACK|PROPER).*', title, re.I)): + results.append(classes.Proper(title, url, datetime.datetime.today(), self.show)) return results @@ -387,23 +366,32 @@ class NewznabCache(tvcache.TVCache): tvcache.TVCache.__init__(self, provider) - # only poll newznab providers every 30 minutes max, doubled so we don't get throttled again. 
+ # only poll newznab providers every 30 minutes self.minTime = 30 + self.last_search = datetime.datetime.now() def _getRSSData(self): params = {"t": "tvsearch", "cat": self.provider.catIDs + ',5060,5070', - "attrs": "rageid"} + "attrs": "rageid", + "maxage": 4, + } if self.provider.needs_auth and self.provider.key: params['apikey'] = self.provider.key rss_url = self.provider.url + 'api?' + urllib.urlencode(params) + while((datetime.datetime.now() - self.last_search).seconds < 5): + time.sleep(1) + logger.log(self.provider.name + " cache update URL: " + rss_url, logger.DEBUG) + data = self.getRSSFeed(rss_url) + + self.last_search = datetime.datetime.now() - return self.getRSSFeed(rss_url) + return data def _checkAuth(self, data): return self.provider._checkAuthFromData(data) diff --git a/sickbeard/providers/rarbg.py b/sickbeard/providers/rarbg.py index 6b11252a90a9c5536d42f9b0cf89ef00063af1e4..6ee82f094880d11224de8420d5d4ab1355283587 100644 --- a/sickbeard/providers/rarbg.py +++ b/sickbeard/providers/rarbg.py @@ -112,7 +112,7 @@ class RarbgProvider(generic.TorrentProvider): response = self.session.get(self.urls['token'], timeout=30, verify=False, headers=self.headers) response.raise_for_status() resp_json = response.json() - except RequestException as e: + except (RequestException, BaseSSLError) as e: logger.log(u'Unable to connect to {name} provider: {error}'.format(name=self.name, error=ex(e)), logger.ERROR) return False diff --git a/sickbeard/providers/torrentday.py b/sickbeard/providers/torrentday.py index 6ba30f5044535a8ce90fdf888fab923a136498a5..e8476b0f1bbe9a4f45c5bc5a0fb6ceae933ba367 100644 --- a/sickbeard/providers/torrentday.py +++ b/sickbeard/providers/torrentday.py @@ -54,10 +54,10 @@ class TorrentDayProvider(generic.TorrentProvider): self.cache = TorrentDayCache(self) - self.urls = {'base_url': 'https://tdonline.org', - 'login': 'https://tdonline.org/torrents/', - 'search': 'https://tdonline.org/V3/API/API.php', - 'download': 
'https://tdonline.org/download.php/%s/%s' + self.urls = {'base_url': 'https://classic.torrentday.com', + 'login': 'https://classic.torrentday.com/torrents/', + 'search': 'https://classic.torrentday.com/V3/API/API.php', + 'download': 'https://classic.torrentday.com/download.php/%s/%s' } self.url = self.urls['base_url'] diff --git a/sickbeard/rssfeeds.py b/sickbeard/rssfeeds.py index f87587bc63bb9de957e918a7ad3bddcf208495fd..0a06c49fa0fa6988be975f9d91a9378b64f08ac3 100644 --- a/sickbeard/rssfeeds.py +++ b/sickbeard/rssfeeds.py @@ -38,8 +38,8 @@ class RSSFeeds: url += urllib.urlencode(post_data) try: - resp = Cache(self.rssDB).fetch(url, force_update=True, request_headers=request_headers, handlers=handlers) + resp = Cache(self.rssDB, userAgent=sickbeard.common.USER_AGENT).fetch(url, force_update=True, request_headers=request_headers, handlers=handlers) finally: self.rssDB.close() - return resp \ No newline at end of file + return resp diff --git a/sickbeard/scene_numbering.py b/sickbeard/scene_numbering.py index e72fae4c02fb274d26c85a6d6c0ec864c4acc874..109929c471ef54d8d806772a8c7048c8f2a1d9cf 100644 --- a/sickbeard/scene_numbering.py +++ b/sickbeard/scene_numbering.py @@ -498,7 +498,7 @@ def xem_refresh(indexer_id, indexer, force=False): try: parsedJSON = sickbeard.helpers.getURL(url, json=True) if not parsedJSON or parsedJSON == '': - logger.log(u'No XEN data for show "%s on %s"' % (indexer_id, sickbeard.indexerApi(indexer).name,), logger.INFO) + logger.log(u'No XEM data for show "%s on %s"' % (indexer_id, sickbeard.indexerApi(indexer).name,), logger.INFO) return if 'success' in parsedJSON['result']: diff --git a/sickbeard/search.py b/sickbeard/search.py index e27f64b4a657c1178ee1ec8f5ea280d31894adcd..b8e8d0b9fcc9c70d9324abaa481d04833605a16f 100644 --- a/sickbeard/search.py +++ b/sickbeard/search.py @@ -380,7 +380,17 @@ def searchForNeededEpisodes(): for curProvider in providers: threading.currentThread().name = origThreadName + " :: [" + curProvider.name + "]" - 
curFoundResults = curProvider.searchRSS(episodes) + curFoundResults = {} + try: + curFoundResults = curProvider.searchRSS(episodes) + except exceptions.AuthException, e: + logger.log(u"Authentication error: " + ex(e), logger.ERROR) + continue + except Exception, e: + logger.log(u"Error while searching " + curProvider.name + ", skipping: " + ex(e), logger.ERROR) + logger.log(traceback.format_exc(), logger.DEBUG) + continue + didSearch = True # pick a single result for each episode, respecting existing results diff --git a/sickbeard/tv.py b/sickbeard/tv.py index 1b0b8199e7ac9fd68a6e0b1103eb0e08da3c5cb6..c70d7efb4f3e058613c6860a7127739809852f20 100644 --- a/sickbeard/tv.py +++ b/sickbeard/tv.py @@ -2500,7 +2500,7 @@ class TVEpisode(object): # split off the dirs only, if they exist name_groups = re.split(r'[\\/]', pattern) - return self._format_pattern(name_groups[-1], multi, anime_type) + return helpers.sanitizeFileName(self._format_pattern(name_groups[-1], multi, anime_type)) def rename(self): """ diff --git a/sickbeard/tvcache.py b/sickbeard/tvcache.py index 3e2804cba01b83d6181ddcdda6d9b1a3af9bd3a9..6b3db41511c30ba961de401a44e62328c736240b 100644 --- a/sickbeard/tvcache.py +++ b/sickbeard/tvcache.py @@ -35,7 +35,7 @@ from sickbeard.exceptions import AuthException from sickbeard.rssfeeds import RSSFeeds from name_parser.parser import NameParser, InvalidNameException, InvalidShowException from sickbeard import encodingKludge as ek - +from sickbeard import show_name_helpers class CacheDBConnection(db.DBConnection): def __init__(self, providerName): @@ -316,7 +316,9 @@ class TVCache(): cl = [] myDB = self._getDB() - if type(episode) != list: + if not episode: + sqlResults = myDB.select("SELECT * FROM [" + self.providerID + "]") + elif type(episode) != list: sqlResults = myDB.select( "SELECT * FROM [" + self.providerID + "] WHERE indexerid = ? AND season = ? 
AND episodes LIKE ?", [episode.show.indexerid, episode.season, "%|" + str(episode.episode) + "|%"]) @@ -332,6 +334,10 @@ class TVCache(): # for each cache entry for curResult in sqlResults: + # ignored/required words, and non-tv junk + if not show_name_helpers.filterBadReleases(curResult["name"]): + continue + # get the show object, or if it's not one of our shows then ignore it showObj = helpers.findCertainShow(sickbeard.showList, int(curResult["indexerid"])) if not showObj: @@ -346,9 +352,11 @@ class TVCache(): curSeason = int(curResult["season"]) if curSeason == -1: continue + curEp = curResult["episodes"].split("|")[1] if not curEp: continue + curEp = int(curEp) curQuality = int(curResult["quality"]) diff --git a/sickbeard/webserve.py b/sickbeard/webserve.py index 7d99e1a615aaf4950f3bc2fbb43b31fc95401196..aeef2a29bbd0733c5a50639d13a7c85126a5305e 100644 --- a/sickbeard/webserve.py +++ b/sickbeard/webserve.py @@ -1540,7 +1540,7 @@ class Home(WebRoot): if not sickbeard.traktRollingScheduler.action.updateWantedList(showObj.indexerid): errors.append("Unable to force an update on wanted episode") - ui.notifications.message('<b>%s</b> has been %s' % (showObj.name,('resumed', 'paused')[showObj.paused])) + ui.notifications.message('%s has been %s' % (showObj.name,('resumed', 'paused')[showObj.paused])) return self.redirect("/home/displayShow?show=" + show) def deleteShow(self, show=None, full=0): @@ -1567,7 +1567,7 @@ class Home(WebRoot): showObj.deleteShow(bool(full)) - ui.notifications.message('<b>%s</b> has been %s %s' % + ui.notifications.message('%s has been %s %s' % (showObj.name, ('deleted', 'trashed')[sickbeard.TRASH_REMOVE_SHOW], ('(media untouched)', '(with all related media)')[bool(full)])) @@ -3680,7 +3680,7 @@ class ConfigGeneral(Config): def saveGeneral(self, log_dir=None, log_nr = 5, log_size = 1048576, web_port=None, web_log=None, encryption_version=None, web_ipv6=None, update_shows_on_start=None, update_shows_on_snatch=None, trash_remove_show=None, 
trash_rotate_logs=None, update_frequency=None, indexerDefaultLang='en', ep_default_deleted_status=None, launch_browser=None, showupdate_hour=3, web_username=None, - api_key=None, indexer_default=None, timezone_display=None, cpu_preset=None, + api_key=None, indexer_default=None, timezone_display=None, cpu_preset='NORMAL', web_password=None, version_notify=None, enable_https=None, https_cert=None, https_key=None, handle_reverse_proxy=None, sort_article=None, auto_update=None, notify_on_update=None, proxy_setting=None, proxy_indexers=None, anon_redirect=None, git_path=None, git_remote=None, diff --git a/tests/feedparser_tests.py b/tests/feedparser_tests.py index d7505d9419e37720714460957d4e131e0fa84e48..b5c02f1f8136f2b7b6760c320111540d3d388dc7 100644 --- a/tests/feedparser_tests.py +++ b/tests/feedparser_tests.py @@ -6,14 +6,15 @@ sys.path.insert(1, os.path.abspath(os.path.join(os.path.dirname(__file__), '../l sys.path.insert(1, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from sickbeard.rssfeeds import RSSFeeds - +from sickbeard.tvcache import TVCache class FeedParserTests(unittest.TestCase): - def test_newznab(self): + def test_womble(self): RSSFeeds().clearCache() - result = RSSFeeds().getFeed('http://lolo.sickbeard.com/api?t=caps') + result = RSSFeeds().getFeed('https://newshost.co.za/rss/?sec=tv-sd&fr=false') self.assertTrue('entries' in result) self.assertTrue('feed' in result) - self.assertTrue('categories' in result.feed) + for item in result['entries']: + self.assertTrue(TVCache._parseItem(item)) if __name__ == "__main__": print "==================" diff --git a/tests/snatch_tests.py b/tests/snatch_tests.py index ad06483beb3ca882651fac60fd8fce5ae7240ad4..6f68c9366103d0f45c34ee7fc6af82208f871055 100644 --- a/tests/snatch_tests.py +++ b/tests/snatch_tests.py @@ -83,7 +83,7 @@ def test_generator(tvdbdid, show_name, curData, forceSearch): episode.status = c.WANTED episode.saveToDB() - bestResult = search.searchProviders(show, episode.season, 
episode.episode, forceSearch) + bestResult = search.searchProviders(show, episode.episode, forceSearch) if not bestResult: self.assertEqual(curData["b"], bestResult) self.assertEqual(curData["b"], bestResult.name) #first is expected, second is choosen one diff --git a/tests/test_lib.py b/tests/test_lib.py index 841b3a9688dfd243d716c810896610522b56e079..a4077eddc5a587081f62f0e47ea35cfc0e467e03 100644 --- a/tests/test_lib.py +++ b/tests/test_lib.py @@ -43,7 +43,7 @@ shutil.copyfile = lib.shutil_custom.copyfile_custom #================= # test globals #================= -TESTDIR = os.path.abspath('.') +TESTDIR = os.path.abspath(os.path.dirname(__file__)) TESTDBNAME = "sickbeard.db" TESTCACHEDBNAME = "cache.db" TESTFAILEDDBNAME = "failed.db" @@ -88,8 +88,8 @@ sickbeard.PROVIDER_ORDER = ["sick_beard_index"] sickbeard.newznabProviderList = providers.getNewznabProviderList("'Sick Beard Index|http://lolo.sickbeard.com/|0|5030,5040|0|eponly|0|0|0!!!NZBs.org|https://nzbs.org/||5030,5040,5060,5070,5090|0|eponly|0|0|0!!!Usenet-Crawler|https://www.usenet-crawler.com/||5030,5040,5060|0|eponly|0|0|0'") sickbeard.providerList = providers.makeProviderList() -sickbeard.PROG_DIR = os.path.abspath('..') -sickbeard.DATA_DIR = sickbeard.PROG_DIR +sickbeard.PROG_DIR = os.path.abspath(os.path.join(TESTDIR, '..')) +sickbeard.DATA_DIR = TESTDIR sickbeard.CONFIG_FILE = os.path.join(sickbeard.DATA_DIR, "config.ini") sickbeard.CFG = ConfigObj(sickbeard.CONFIG_FILE) @@ -140,34 +140,36 @@ class SickbeardTestDBCase(unittest.TestCase): tearDown_test_episode_file() tearDown_test_show_dir() - class TestDBConnection(db.DBConnection, object): def __init__(self, dbFileName=TESTDBNAME): dbFileName = os.path.join(TESTDIR, dbFileName) super(TestDBConnection, self).__init__(dbFileName) - class TestCacheDBConnection(TestDBConnection, object): - - def __init__(self, providerName): + def __init__(self, providerName): db.DBConnection.__init__(self, os.path.join(TESTDIR, TESTCACHEDBNAME)) # Create the table if 
it's not already there try: - sql = "CREATE TABLE " + providerName + " (name TEXT, season NUMERIC, episodes TEXT, indexerid NUMERIC, url TEXT, time NUMERIC, quality TEXT);" - self.connection.execute(sql) - self.connection.commit() - except sqlite3.OperationalError, e: - if str(e) != "table " + providerName + " already exists": + if not self.hasTable(providerName): + sql = "CREATE TABLE [" + providerName + "] (name TEXT, season NUMERIC, episodes TEXT, indexerid NUMERIC, url TEXT, time NUMERIC, quality TEXT, release_group TEXT)" + self.connection.execute(sql) + self.connection.commit() + except Exception, e: + if str(e) != "table [" + providerName + "] already exists": raise + # add version column to table if missing + if not self.hasColumn(providerName, 'version'): + self.addColumn(providerName, 'version', "NUMERIC", "-1") + # Create the table if it's not already there try: sql = "CREATE TABLE lastUpdate (provider TEXT, time NUMERIC);" self.connection.execute(sql) self.connection.commit() - except sqlite3.OperationalError, e: + except Exception, e: if str(e) != "table lastUpdate already exists": raise @@ -195,31 +197,29 @@ def setUp_test_db(): def tearDown_test_db(): - """Deletes the test db - although this seams not to work on my system it leaves me with an zero kb file - """ - - #uncomment next line so leave the db intact between test and at the end - #return False - - for current_db in [ TESTDBNAME, TESTCACHEDBNAME, TESTFAILEDDBNAME ]: - for file_name in [ os.path.join(TESTDIR, current_db), os.path.join(TESTDIR, current_db + '-journal') ]: - if os.path.exists(file_name): - try: - os.remove(file_name) - except (IOError, OSError) as e: - print 'ERROR: Failed to remove ' + file_name - print ex(e) - + from sickbeard.db import db_cons + for connection in db_cons: + db_cons[connection].commit() +# db_cons[connection].close() + +# for current_db in [ TESTDBNAME, TESTCACHEDBNAME, TESTFAILEDDBNAME ]: +# file_name = os.path.join(TESTDIR, current_db) +# if 
os.path.exists(file_name): +# try: +# os.remove(file_name) +# except Exception as e: +# print 'ERROR: Failed to remove ' + file_name +# print ex(e) def setUp_test_episode_file(): if not os.path.exists(FILEDIR): os.makedirs(FILEDIR) try: - with open(FILEPATH, 'w') as f: + with open(FILEPATH, 'wb') as f: f.write("foo bar") - except EnvironmentError: + f.flush() + except Exception: print "Unable to set up test episode" raise diff --git a/tests/tv_tests.py b/tests/tv_tests.py index 9b130d17cd019ab79fe66990c550cc9d2ce13f5e..a6b8cb30957b46e1a090abd359cd30d8f816a9b5 100644 --- a/tests/tv_tests.py +++ b/tests/tv_tests.py @@ -43,7 +43,7 @@ class TVShowTests(test.SickbeardTestDBCase): show.network = "cbs" show.genre = "crime" show.runtime = 40 - show.status = "5" + show.status = "Ended" show.default_ep_status = "5" show.airs = "monday" show.startyear = 1987 @@ -92,7 +92,7 @@ class TVTests(test.SickbeardTestDBCase): show.network = "cbs" show.genre = "crime" show.runtime = 40 - show.status = "5" + show.status = "Ended" show.default_ep_status = "5" show.airs = "monday" show.startyear = 1987 diff --git a/tornado/__init__.py b/tornado/__init__.py index 0e39f842c7c9e0a190c7877b2a51e3c6741920a2..6f4f47d2d9f7f60e8c18e4583579e2681f56e742 100644 --- a/tornado/__init__.py +++ b/tornado/__init__.py @@ -25,5 +25,5 @@ from __future__ import absolute_import, division, print_function, with_statement # is zero for an official release, positive for a development branch, # or negative for a release candidate or beta (after the base version # number has been incremented) -version = "4.1.dev1" -version_info = (4, 1, 0, -100) +version = "4.1" +version_info = (4, 1, 0, 0) diff --git a/tornado/autoreload.py b/tornado/autoreload.py index 3982579ad73024072214116cb4ac418616121400..a548cf02624f1afc42edf15ea2fa5717f709d589 100644 --- a/tornado/autoreload.py +++ b/tornado/autoreload.py @@ -108,7 +108,11 @@ _io_loops = weakref.WeakKeyDictionary() def start(io_loop=None, check_time=500): - """Begins 
watching source files for changes using the given `.IOLoop`. """ + """Begins watching source files for changes. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. + """ io_loop = io_loop or ioloop.IOLoop.current() if io_loop in _io_loops: return diff --git a/tornado/concurrent.py b/tornado/concurrent.py index 6bab5d2e0add53c17682cd3a1058829218823ad0..acfbcd83e820dfd1ff1e6eb9be8af961d5ae0084 100644 --- a/tornado/concurrent.py +++ b/tornado/concurrent.py @@ -25,11 +25,13 @@ module. from __future__ import absolute_import, division, print_function, with_statement import functools +import platform +import traceback import sys +from tornado.log import app_log from tornado.stack_context import ExceptionStackContext, wrap from tornado.util import raise_exc_info, ArgReplacer -from tornado.log import app_log try: from concurrent import futures @@ -37,9 +39,88 @@ except ImportError: futures = None +# Can the garbage collector handle cycles that include __del__ methods? +# This is true in cpython beginning with version 3.4 (PEP 442). +_GC_CYCLE_FINALIZERS = (platform.python_implementation() == 'CPython' and + sys.version_info >= (3, 4)) + class ReturnValueIgnoredError(Exception): pass +# This class and associated code in the future object is derived +# from the Trollius project, a backport of asyncio to Python 2.x - 3.x + +class _TracebackLogger(object): + """Helper to log a traceback upon destruction if not cleared. + + This solves a nasty problem with Futures and Tasks that have an + exception set: if nobody asks for the exception, the exception is + never logged. This violates the Zen of Python: 'Errors should + never pass silently. Unless explicitly silenced.' + + However, we don't want to log the exception as soon as + set_exception() is called: if the calling code is written + properly, it will get the exception and handle it properly. 
But + we *do* want to log it if result() or exception() was never called + -- otherwise developers waste a lot of time wondering why their + buggy code fails silently. + + An earlier attempt added a __del__() method to the Future class + itself, but this backfired because the presence of __del__() + prevents garbage collection from breaking cycles. A way out of + this catch-22 is to avoid having a __del__() method on the Future + class itself, but instead to have a reference to a helper object + with a __del__() method that logs the traceback, where we ensure + that the helper object doesn't participate in cycles, and only the + Future has a reference to it. + + The helper object is added when set_exception() is called. When + the Future is collected, and the helper is present, the helper + object is also collected, and its __del__() method will log the + traceback. When the Future's result() or exception() method is + called (and a helper object is present), it removes the the helper + object, after calling its clear() method to prevent it from + logging. + + One downside is that we do a fair amount of work to extract the + traceback from the exception, even when it is never logged. It + would seem cheaper to just store the exception object, but that + references the traceback, which references stack frames, which may + reference the Future, which references the _TracebackLogger, and + then the _TracebackLogger would be included in a cycle, which is + what we're trying to avoid! As an optimization, we don't + immediately format the exception; we only do the work when + activate() is called, which call is delayed until after all the + Future's callbacks have run. Since usually a Future has at least + one callback (typically set by 'yield From') and usually that + callback extracts the callback, thereby removing the need to + format the exception. + + PS. I don't claim credit for this solution. 
I first heard of it + in a discussion about closing files when they are collected. + """ + + __slots__ = ('exc_info', 'formatted_tb') + + def __init__(self, exc_info): + self.exc_info = exc_info + self.formatted_tb = None + + def activate(self): + exc_info = self.exc_info + if exc_info is not None: + self.exc_info = None + self.formatted_tb = traceback.format_exception(*exc_info) + + def clear(self): + self.exc_info = None + self.formatted_tb = None + + def __del__(self): + if self.formatted_tb: + app_log.error('Future exception was never retrieved: %s', + ''.join(self.formatted_tb).rstrip()) + class Future(object): """Placeholder for an asynchronous result. @@ -68,12 +149,23 @@ class Future(object): if that package was available and fall back to the thread-unsafe implementation if it was not. + .. versionchanged:: 4.1 + If a `.Future` contains an error but that error is never observed + (by calling ``result()``, ``exception()``, or ``exc_info()``), + a stack trace will be logged when the `.Future` is garbage collected. + This normally indicates an error in the application, but in cases + where it results in undesired logging it may be necessary to + suppress the logging by ensuring that the exception is observed: + ``f.add_done_callback(lambda f: f.exception())``. """ def __init__(self): self._done = False self._result = None - self._exception = None self._exc_info = None + + self._log_traceback = False # Used for Python >= 3.4 + self._tb_logger = None # Used for Python <= 3.3 + self._callbacks = [] def cancel(self): @@ -100,16 +192,21 @@ class Future(object): """Returns True if the future has finished running.""" return self._done + def _clear_tb_log(self): + self._log_traceback = False + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + def result(self, timeout=None): """If the operation succeeded, return its result. If it failed, re-raise its exception. 
""" + self._clear_tb_log() if self._result is not None: return self._result if self._exc_info is not None: raise_exc_info(self._exc_info) - elif self._exception is not None: - raise self._exception self._check_done() return self._result @@ -117,8 +214,9 @@ class Future(object): """If the operation raised an exception, return the `Exception` object. Otherwise returns None. """ - if self._exception is not None: - return self._exception + self._clear_tb_log() + if self._exc_info is not None: + return self._exc_info[1] else: self._check_done() return None @@ -147,14 +245,17 @@ class Future(object): def set_exception(self, exception): """Sets the exception of a ``Future.``""" - self._exception = exception - self._set_done() + self.set_exc_info( + (exception.__class__, + exception, + getattr(exception, '__traceback__', None))) def exc_info(self): """Returns a tuple in the same format as `sys.exc_info` or None. .. versionadded:: 4.0 """ + self._clear_tb_log() return self._exc_info def set_exc_info(self, exc_info): @@ -165,7 +266,18 @@ class Future(object): .. versionadded:: 4.0 """ self._exc_info = exc_info - self.set_exception(exc_info[1]) + self._log_traceback = True + if not _GC_CYCLE_FINALIZERS: + self._tb_logger = _TracebackLogger(exc_info) + + try: + self._set_done() + finally: + # Activate the logger after all callbacks have had a + # chance to call result() or exception(). + if self._log_traceback and self._tb_logger is not None: + self._tb_logger.activate() + self._exc_info = exc_info def _check_done(self): if not self._done: @@ -181,6 +293,21 @@ class Future(object): cb, self) self._callbacks = None + # On Python 3.3 or older, objects with a destructor part of a reference + # cycle are never destroyed. It's no longer the case on Python 3.4 thanks to + # the PEP 442. 
+ if _GC_CYCLE_FINALIZERS: + def __del__(self): + if not self._log_traceback: + # set_exception() was not called, or result() or exception() + # has consumed the exception + return + + tb = traceback.format_exception(*self._exc_info) + + app_log.error('Future %r exception was never retrieved: %s', + self, ''.join(tb).rstrip()) + TracebackFuture = Future if futures is None: @@ -293,7 +420,7 @@ def return_future(f): # If the initial synchronous part of f() raised an exception, # go ahead and raise it to the caller directly without waiting # for them to inspect the Future. - raise_exc_info(exc_info) + future.result() # If the caller passed in a callback, schedule it to be called # when the future resolves. It is important that this happens diff --git a/tornado/curl_httpclient.py b/tornado/curl_httpclient.py index 68047cc94880560e21a6e36af0c1010948abe3f2..ebbe0e84b9300485d1acae2ed405b4cf512b0e21 100644 --- a/tornado/curl_httpclient.py +++ b/tornado/curl_httpclient.py @@ -28,12 +28,12 @@ from io import BytesIO from tornado import httputil from tornado import ioloop -from tornado.log import gen_log from tornado import stack_context from tornado.escape import utf8, native_str from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main +curl_log = logging.getLogger('tornado.curl_httpclient') class CurlAsyncHTTPClient(AsyncHTTPClient): def initialize(self, io_loop, max_clients=10, defaults=None): @@ -257,7 +257,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): def _curl_create(self): curl = pycurl.Curl() - if gen_log.isEnabledFor(logging.DEBUG): + if curl_log.isEnabledFor(logging.DEBUG): curl.setopt(pycurl.VERBOSE, 1) curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug) return curl @@ -403,11 +403,11 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): raise ValueError("Unsupported auth_mode %s" % request.auth_mode) curl.setopt(pycurl.USERPWD, native_str(userpwd)) - gen_log.debug("%s %s (username: %r)", request.method, request.url, + curl_log.debug("%s %s 
(username: %r)", request.method, request.url, request.auth_username) else: curl.unsetopt(pycurl.USERPWD) - gen_log.debug("%s %s", request.method, request.url) + curl_log.debug("%s %s", request.method, request.url) if request.client_cert is not None: curl.setopt(pycurl.SSLCERT, request.client_cert) @@ -448,12 +448,12 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): def _curl_debug(self, debug_type, debug_msg): debug_types = ('I', '<', '>', '<', '>') if debug_type == 0: - gen_log.debug('%s', debug_msg.strip()) + curl_log.debug('%s', debug_msg.strip()) elif debug_type in (1, 2): for line in debug_msg.splitlines(): - gen_log.debug('%s %s', debug_types[debug_type], line) + curl_log.debug('%s %s', debug_types[debug_type], line) elif debug_type == 4: - gen_log.debug('%s %r', debug_types[debug_type], debug_msg) + curl_log.debug('%s %r', debug_types[debug_type], debug_msg) class CurlError(HTTPError): diff --git a/tornado/gen.py b/tornado/gen.py index 2fc9b0c70538bcb7694872049e5fe4a8e07ab07c..86fe2f19596ea77267af838b65c7faa03541b16a 100644 --- a/tornado/gen.py +++ b/tornado/gen.py @@ -43,8 +43,21 @@ be returned when they are all finished:: response3 = response_dict['response3'] response4 = response_dict['response4'] +If the `~functools.singledispatch` library is available (standard in +Python 3.4, available via the `singledispatch +<https://pypi.python.org/pypi/singledispatch>`_ package on older +versions), additional types of objects may be yielded. Tornado includes +support for ``asyncio.Future`` and Twisted's ``Deferred`` class when +``tornado.platform.asyncio`` and ``tornado.platform.twisted`` are imported. +See the `convert_yielded` function to extend this mechanism. + .. versionchanged:: 3.2 Dict support added. + +.. versionchanged:: 4.1 + Support added for yielding ``asyncio`` Futures and Twisted Deferreds + via ``singledispatch``. 
+ """ from __future__ import absolute_import, division, print_function, with_statement @@ -53,11 +66,21 @@ import functools import itertools import sys import types +import weakref from tornado.concurrent import Future, TracebackFuture, is_future, chain_future from tornado.ioloop import IOLoop +from tornado.log import app_log from tornado import stack_context +try: + from functools import singledispatch # py34+ +except ImportError as e: + try: + from singledispatch import singledispatch # backport + except ImportError: + singledispatch = None + class KeyReuseError(Exception): pass @@ -240,6 +263,106 @@ class Return(Exception): super(Return, self).__init__() self.value = value +class WaitIterator(object): + """Provides an iterator to yield the results of futures as they finish. + + Yielding a set of futures like this: + + ``results = yield [future1, future2]`` + + pauses the coroutine until both ``future1`` and ``future2`` + return, and then restarts the coroutine with the results of both + futures. If either future is an exception, the expression will + raise that exception and all the results will be lost. + + If you need to get the result of each future as soon as possible, + or if you need the result of some futures even if others produce + errors, you can use ``WaitIterator``: + + :: + + wait_iterator = gen.WaitIterator(future1, future2) + while not wait_iterator.done(): + try: + result = yield wait_iterator.next() + except Exception as e: + print "Error {} from {}".format(e, wait_iterator.current_future) + else: + print "Result {} recieved from {} at {}".format( + result, wait_iterator.current_future, + wait_iterator.current_index) + + Because results are returned as soon as they are available the + output from the iterator *will not be in the same order as the + input arguments*. 
If you need to know which future produced the + current result, you can use the attributes + ``WaitIterator.current_future``, or ``WaitIterator.current_index`` + to get the index of the future from the input list. (if keyword + arguments were used in the construction of the `WaitIterator`, + ``current_index`` will use the corresponding keyword). + + .. versionadded:: 4.1 + """ + def __init__(self, *args, **kwargs): + if args and kwargs: + raise ValueError( + "You must provide args or kwargs, not both") + + if kwargs: + self._unfinished = dict((f, k) for (k, f) in kwargs.items()) + futures = list(kwargs.values()) + else: + self._unfinished = dict((f, i) for (i, f) in enumerate(args)) + futures = args + + self._finished = collections.deque() + self.current_index = self.current_future = None + self._running_future = None + + self_ref = weakref.ref(self) + for future in futures: + future.add_done_callback(functools.partial( + self._done_callback, self_ref)) + + def done(self): + """Returns True if this iterator has no more results.""" + if self._finished or self._unfinished: + return False + # Clear the 'current' values when iteration is done. + self.current_index = self.current_future = None + return True + + def next(self): + """Returns a `.Future` that will yield the next available result. + + Note that this `.Future` will not be the same object as any of + the inputs. + """ + self._running_future = TracebackFuture() + + if self._finished: + self._return_result(self._finished.popleft()) + + return self._running_future + + @staticmethod + def _done_callback(self_ref, done): + self = self_ref() + if self is not None: + if self._running_future and not self._running_future.done(): + self._return_result(done) + else: + self._finished.append(done) + + def _return_result(self, done): + """Called set the returned future's state that of the future + we yielded, and set the current future for the iterator. 
+ """ + chain_future(done, self._running_future) + + self.current_future = done + self.current_index = self._unfinished.pop(done) + class YieldPoint(object): """Base class for objects that may be yielded from the generator. @@ -371,6 +494,11 @@ def Task(func, *args, **kwargs): class YieldFuture(YieldPoint): def __init__(self, future, io_loop=None): + """Adapts a `.Future` to the `YieldPoint` interface. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. + """ self.future = future self.io_loop = io_loop or IOLoop.current() @@ -504,7 +632,7 @@ def maybe_future(x): return fut -def with_timeout(timeout, future, io_loop=None): +def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()): """Wraps a `.Future` in a timeout. Raises `TimeoutError` if the input future does not complete before @@ -512,9 +640,17 @@ def with_timeout(timeout, future, io_loop=None): `.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or an absolute time relative to `.IOLoop.time`) + If the wrapped `.Future` fails after it has timed out, the exception + will be logged unless it is of a type contained in ``quiet_exceptions`` + (which may be an exception type or a sequence of types). + Currently only supports Futures, not other `YieldPoint` classes. .. versionadded:: 4.0 + + .. versionchanged:: 4.1 + Added the ``quiet_exceptions`` argument and the logging of unhandled + exceptions. """ # TODO: allow yield points in addition to futures? # Tricky to do with stack_context semantics. @@ -528,9 +664,19 @@ def with_timeout(timeout, future, io_loop=None): chain_future(future, result) if io_loop is None: io_loop = IOLoop.current() + def error_callback(future): + try: + future.result() + except Exception as e: + if not isinstance(e, quiet_exceptions): + app_log.error("Exception in Future %r after timeout", + future, exc_info=True) + def timeout_callback(): + result.set_exception(TimeoutError("Timeout")) + # In case the wrapped future goes on to fail, log it. 
+ future.add_done_callback(error_callback) timeout_handle = io_loop.add_timeout( - timeout, - lambda: result.set_exception(TimeoutError("Timeout"))) + timeout, timeout_callback) if isinstance(future, Future): # We know this future will resolve on the IOLoop, so we don't # need the extra thread-safety of IOLoop.add_future (and we also @@ -545,6 +691,25 @@ def with_timeout(timeout, future, io_loop=None): return result +def sleep(duration): + """Return a `.Future` that resolves after the given number of seconds. + + When used with ``yield`` in a coroutine, this is a non-blocking + analogue to `time.sleep` (which should not be used in coroutines + because it is blocking):: + + yield gen.sleep(0.5) + + Note that calling this function on its own does nothing; you must + wait on the `.Future` it returns (usually by yielding it). + + .. versionadded:: 4.1 + """ + f = Future() + IOLoop.current().call_later(duration, lambda: f.set_result(None)) + return f + + _null_future = Future() _null_future.set_result(None) @@ -678,18 +843,18 @@ class Runner(object): self.running = False def handle_yield(self, yielded): - if isinstance(yielded, list): - if all(is_future(f) for f in yielded): - yielded = multi_future(yielded) - else: - yielded = Multi(yielded) - elif isinstance(yielded, dict): - if all(is_future(f) for f in yielded.values()): - yielded = multi_future(yielded) - else: - yielded = Multi(yielded) + # Lists containing YieldPoints require stack contexts; + # other lists are handled via multi_future in convert_yielded. + if (isinstance(yielded, list) and + any(isinstance(f, YieldPoint) for f in yielded)): + yielded = Multi(yielded) + elif (isinstance(yielded, dict) and + any(isinstance(f, YieldPoint) for f in yielded.values())): + yielded = Multi(yielded) if isinstance(yielded, YieldPoint): + # YieldPoints are too closely coupled to the Runner to go + # through the generic convert_yielded mechanism. 
self.future = TracebackFuture() def start_yield_point(): try: @@ -702,6 +867,7 @@ class Runner(object): except Exception: self.future = TracebackFuture() self.future.set_exc_info(sys.exc_info()) + if self.stack_context_deactivate is None: # Start a stack context if this is the first # YieldPoint we've seen. @@ -715,16 +881,17 @@ class Runner(object): return False else: start_yield_point() - elif is_future(yielded): - self.future = yielded - if not self.future.done() or self.future is moment: - self.io_loop.add_future( - self.future, lambda f: self.run()) - return False else: - self.future = TracebackFuture() - self.future.set_exception(BadYieldError( - "yielded unknown object %r" % (yielded,))) + try: + self.future = convert_yielded(yielded) + except BadYieldError: + self.future = TracebackFuture() + self.future.set_exc_info(sys.exc_info()) + + if not self.future.done() or self.future is moment: + self.io_loop.add_future( + self.future, lambda f: self.run()) + return False return True def result_callback(self, key): @@ -763,3 +930,30 @@ def _argument_adapter(callback): else: callback(None) return wrapper + + +def convert_yielded(yielded): + """Convert a yielded object into a `.Future`. + + The default implementation accepts lists, dictionaries, and Futures. + + If the `~functools.singledispatch` library is available, this function + may be extended to support additional types. For example:: + + @convert_yielded.register(asyncio.Future) + def _(asyncio_future): + return tornado.platform.asyncio.to_tornado_future(asyncio_future) + + .. versionadded:: 4.1 + """ + # Lists and dicts containing YieldPoints were handled separately + # via Multi(). 
+ if isinstance(yielded, (list, dict)): + return multi_future(yielded) + elif is_future(yielded): + return yielded + else: + raise BadYieldError("yielded unknown object %r" % (yielded,)) + +if singledispatch is not None: + convert_yielded = singledispatch(convert_yielded) diff --git a/tornado/http1connection.py b/tornado/http1connection.py index 90895cc94393d8ef0fd662f759bc4b1d76f0d931..181319c42e4e192c74cc4bb20dc80d51b3bbc86a 100644 --- a/tornado/http1connection.py +++ b/tornado/http1connection.py @@ -162,7 +162,8 @@ class HTTP1Connection(httputil.HTTPConnection): header_data = yield gen.with_timeout( self.stream.io_loop.time() + self.params.header_timeout, header_future, - io_loop=self.stream.io_loop) + io_loop=self.stream.io_loop, + quiet_exceptions=iostream.StreamClosedError) except gen.TimeoutError: self.close() raise gen.Return(False) @@ -221,7 +222,8 @@ class HTTP1Connection(httputil.HTTPConnection): try: yield gen.with_timeout( self.stream.io_loop.time() + self._body_timeout, - body_future, self.stream.io_loop) + body_future, self.stream.io_loop, + quiet_exceptions=iostream.StreamClosedError) except gen.TimeoutError: gen_log.info("Timeout reading body from %s", self.context) @@ -326,8 +328,10 @@ class HTTP1Connection(httputil.HTTPConnection): def write_headers(self, start_line, headers, chunk=None, callback=None): """Implements `.HTTPConnection.write_headers`.""" + lines = [] if self.is_client: self._request_start_line = start_line + lines.append(utf8('%s %s HTTP/1.1' % (start_line[0], start_line[1]))) # Client requests with a non-empty body must have either a # Content-Length or a Transfer-Encoding. 
self._chunking_output = ( @@ -336,6 +340,7 @@ class HTTP1Connection(httputil.HTTPConnection): 'Transfer-Encoding' not in headers) else: self._response_start_line = start_line + lines.append(utf8('HTTP/1.1 %s %s' % (start_line[1], start_line[2]))) self._chunking_output = ( # TODO: should this use # self._request_start_line.version or @@ -365,7 +370,6 @@ class HTTP1Connection(httputil.HTTPConnection): self._expected_content_remaining = int(headers['Content-Length']) else: self._expected_content_remaining = None - lines = [utf8("%s %s %s" % start_line)] lines.extend([utf8(n) + b": " + utf8(v) for n, v in headers.get_all()]) for line in lines: if b'\n' in line: @@ -374,6 +378,7 @@ class HTTP1Connection(httputil.HTTPConnection): if self.stream.closed(): future = self._write_future = Future() future.set_exception(iostream.StreamClosedError()) + future.exception() else: if callback is not None: self._write_callback = stack_context.wrap(callback) @@ -412,6 +417,7 @@ class HTTP1Connection(httputil.HTTPConnection): if self.stream.closed(): future = self._write_future = Future() self._write_future.set_exception(iostream.StreamClosedError()) + self._write_future.exception() else: if callback is not None: self._write_callback = stack_context.wrap(callback) @@ -451,6 +457,9 @@ class HTTP1Connection(httputil.HTTPConnection): self._pending_write.add_done_callback(self._finish_request) def _on_write_complete(self, future): + exc = future.exception() + if exc is not None and not isinstance(exc, iostream.StreamClosedError): + future.result() if self._write_callback is not None: callback = self._write_callback self._write_callback = None @@ -491,8 +500,9 @@ class HTTP1Connection(httputil.HTTPConnection): # we SHOULD ignore at least one empty line before the request. # http://tools.ietf.org/html/rfc7230#section-3.5 data = native_str(data.decode('latin1')).lstrip("\r\n") - eol = data.find("\r\n") - start_line = data[:eol] + # RFC 7230 section allows for both CRLF and bare LF. 
+ eol = data.find("\n") + start_line = data[:eol].rstrip("\r") try: headers = httputil.HTTPHeaders.parse(data[eol:]) except ValueError: diff --git a/tornado/httpclient.py b/tornado/httpclient.py index df4295171eec8de6efdeed2bd29b9abcd8aaa89a..0ae9e4802fba353cd62e5854813232f88540561f 100644 --- a/tornado/httpclient.py +++ b/tornado/httpclient.py @@ -95,7 +95,8 @@ class HTTPClient(object): If it is a string, we construct an `HTTPRequest` using any additional kwargs: ``HTTPRequest(request, **kwargs)`` - If an error occurs during the fetch, we raise an `HTTPError`. + If an error occurs during the fetch, we raise an `HTTPError` unless + the ``raise_error`` keyword argument is set to False. """ response = self._io_loop.run_sync(functools.partial( self._async_client.fetch, request, **kwargs)) @@ -136,6 +137,9 @@ class AsyncHTTPClient(Configurable): # or with force_instance: client = AsyncHTTPClient(force_instance=True, defaults=dict(user_agent="MyUserAgent")) + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. """ @classmethod def configurable_base(cls): @@ -200,7 +204,7 @@ class AsyncHTTPClient(Configurable): raise RuntimeError("inconsistent AsyncHTTPClient cache") del self._instance_cache[self.io_loop] - def fetch(self, request, callback=None, **kwargs): + def fetch(self, request, callback=None, raise_error=True, **kwargs): """Executes a request, asynchronously returning an `HTTPResponse`. The request may be either a string URL or an `HTTPRequest` object. @@ -208,8 +212,10 @@ class AsyncHTTPClient(Configurable): kwargs: ``HTTPRequest(request, **kwargs)`` This method returns a `.Future` whose result is an - `HTTPResponse`. The ``Future`` will raise an `HTTPError` if - the request returned a non-200 response code. + `HTTPResponse`. By default, the ``Future`` will raise an `HTTPError` + if the request returned a non-200 response code. Instead, if + ``raise_error`` is set to False, the response will always be + returned regardless of the response code. 
If a ``callback`` is given, it will be invoked with the `HTTPResponse`. In the callback interface, `HTTPError` is not automatically raised. @@ -243,7 +249,7 @@ class AsyncHTTPClient(Configurable): future.add_done_callback(handle_future) def handle_response(response): - if response.error: + if raise_error and response.error: future.set_exception(response.error) else: future.set_result(response) diff --git a/tornado/httpserver.py b/tornado/httpserver.py index 05d0e186c31e789a8df547ff14e7af787b637bbe..e470e0e7d153418a940ccb4526007374f783ae90 100644 --- a/tornado/httpserver.py +++ b/tornado/httpserver.py @@ -42,30 +42,10 @@ from tornado.tcpserver import TCPServer class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): r"""A non-blocking, single-threaded HTTP server. - A server is defined by either a request callback that takes a - `.HTTPServerRequest` as an argument or a `.HTTPServerConnectionDelegate` - instance. - - A simple example server that echoes back the URI you requested:: - - import tornado.httpserver - import tornado.ioloop - from tornado import httputil - - def handle_request(request): - message = "You requested %s\n" % request.uri - request.connection.write_headers( - httputil.ResponseStartLine('HTTP/1.1', 200, 'OK'), - {"Content-Length": str(len(message))}) - request.connection.write(message) - request.connection.finish() - - http_server = tornado.httpserver.HTTPServer(handle_request) - http_server.listen(8888) - tornado.ioloop.IOLoop.instance().start() - - Applications should use the methods of `.HTTPConnection` to write - their response. + A server is defined by a subclass of `.HTTPServerConnectionDelegate`, + or, for backwards compatibility, a callback that takes an + `.HTTPServerRequest` as an argument. The delegate is usually a + `tornado.web.Application`. 
`HTTPServer` supports keep-alive connections by default (automatically for HTTP/1.1, or for HTTP/1.0 when the client @@ -134,6 +114,11 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): ``idle_connection_timeout``, ``body_timeout``, ``max_body_size`` arguments. Added support for `.HTTPServerConnectionDelegate` instances as ``request_callback``. + + .. versionchanged:: 4.1 + `.HTTPServerConnectionDelegate.start_request` is now called with + two arguments ``(server_conn, request_conn)`` (in accordance with the + documentation) instead of one ``(request_conn)``. """ def __init__(self, request_callback, no_keep_alive=False, io_loop=None, xheaders=False, ssl_options=None, protocol=None, @@ -173,7 +158,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): conn.start_serving(self) def start_request(self, server_conn, request_conn): - return _ServerRequestAdapter(self, request_conn) + return _ServerRequestAdapter(self, server_conn, request_conn) def on_close(self, server_conn): self._connections.remove(server_conn) @@ -246,13 +231,14 @@ class _ServerRequestAdapter(httputil.HTTPMessageDelegate): """Adapts the `HTTPMessageDelegate` interface to the interface expected by our clients. 
""" - def __init__(self, server, connection): + def __init__(self, server, server_conn, request_conn): self.server = server - self.connection = connection + self.connection = request_conn self.request = None if isinstance(server.request_callback, httputil.HTTPServerConnectionDelegate): - self.delegate = server.request_callback.start_request(connection) + self.delegate = server.request_callback.start_request( + server_conn, request_conn) self._chunks = None else: self.delegate = None diff --git a/tornado/httputil.py b/tornado/httputil.py index f5c9c04fea3e197dafaca2329a797b7ee311fb4a..9c99b3efa8ec820669dc7198825d295b693ce578 100644 --- a/tornado/httputil.py +++ b/tornado/httputil.py @@ -62,6 +62,11 @@ except ImportError: pass +# RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line +# terminator and ignore any preceding CR. +_CRLF_RE = re.compile(r'\r?\n') + + class _NormalizedHeaderCache(dict): """Dynamic cached mapping of header names to Http-Header-Case. @@ -193,7 +198,7 @@ class HTTPHeaders(dict): [('Content-Length', '42'), ('Content-Type', 'text/html')] """ h = cls() - for line in headers.splitlines(): + for line in _CRLF_RE.split(headers): if line: h.parse_line(line) return h @@ -331,7 +336,7 @@ class HTTPServerRequest(object): self.uri = uri self.version = version self.headers = headers or HTTPHeaders() - self.body = body or "" + self.body = body or b"" # set remote IP and protocol context = getattr(connection, 'context', None) @@ -543,6 +548,8 @@ class HTTPConnection(object): headers. :arg callback: a callback to be run when the write is complete. + The ``version`` field of ``start_line`` is ignored. + Returns a `.Future` if no callback is given. 
""" raise NotImplementedError() @@ -689,14 +696,17 @@ def parse_body_arguments(content_type, body, arguments, files, headers=None): if values: arguments.setdefault(name, []).extend(values) elif content_type.startswith("multipart/form-data"): - fields = content_type.split(";") - for field in fields: - k, sep, v = field.strip().partition("=") - if k == "boundary" and v: - parse_multipart_form_data(utf8(v), body, arguments, files) - break - else: - gen_log.warning("Invalid multipart/form-data") + try: + fields = content_type.split(";") + for field in fields: + k, sep, v = field.strip().partition("=") + if k == "boundary" and v: + parse_multipart_form_data(utf8(v), body, arguments, files) + break + else: + raise ValueError("multipart boundary not found") + except Exception as e: + gen_log.warning("Invalid multipart/form-data: %s", e) def parse_multipart_form_data(boundary, data, arguments, files): @@ -782,7 +792,7 @@ def parse_request_start_line(line): method, path, version = line.split(" ") except ValueError: raise HTTPInputError("Malformed HTTP request line") - if not version.startswith("HTTP/"): + if not re.match(r"^HTTP/1\.[0-9]$", version): raise HTTPInputError( "Malformed HTTP version in HTTP Request-Line: %r" % version) return RequestStartLine(method, path, version) @@ -801,7 +811,7 @@ def parse_response_start_line(line): ResponseStartLine(version='HTTP/1.1', code=200, reason='OK') """ line = native_str(line) - match = re.match("(HTTP/1.[01]) ([0-9]+) ([^\r]*)", line) + match = re.match("(HTTP/1.[0-9]) ([0-9]+) ([^\r]*)", line) if not match: raise HTTPInputError("Error parsing response start line") return ResponseStartLine(match.group(1), int(match.group(2)), @@ -873,3 +883,19 @@ def _encode_header(key, pdict): def doctests(): import doctest return doctest.DocTestSuite() + +def split_host_and_port(netloc): + """Returns ``(host, port)`` tuple from ``netloc``. + + Returned ``port`` will be ``None`` if not present. + + .. 
versionadded:: 4.1 + """ + match = re.match(r'^(.+):(\d+)$', netloc) + if match: + host = match.group(1) + port = int(match.group(2)) + else: + host = netloc + port = None + return (host, port) diff --git a/tornado/ioloop.py b/tornado/ioloop.py index 03193865bc0c811e089068ca352b33bbf7720c85..680dc4016a40f7bfea92a631ddfd8767aca5e6df 100644 --- a/tornado/ioloop.py +++ b/tornado/ioloop.py @@ -167,28 +167,26 @@ class IOLoop(Configurable): del IOLoop._instance @staticmethod - def current(): + def current(instance=True): """Returns the current thread's `IOLoop`. - If an `IOLoop` is currently running or has been marked as current - by `make_current`, returns that instance. Otherwise returns - `IOLoop.instance()`, i.e. the main thread's `IOLoop`. - - A common pattern for classes that depend on ``IOLoops`` is to use - a default argument to enable programs with multiple ``IOLoops`` - but not require the argument for simpler applications:: - - class MyClass(object): - def __init__(self, io_loop=None): - self.io_loop = io_loop or IOLoop.current() + If an `IOLoop` is currently running or has been marked as + current by `make_current`, returns that instance. If there is + no current `IOLoop`, returns `IOLoop.instance()` (i.e. the + main thread's `IOLoop`, creating one if necessary) if ``instance`` + is true. In general you should use `IOLoop.current` as the default when constructing an asynchronous object, and use `IOLoop.instance` when you mean to communicate to the main thread from a different one. + + .. versionchanged:: 4.1 + Added ``instance`` argument to control the + """ current = getattr(IOLoop._current, "instance", None) - if current is None: + if current is None and instance: return IOLoop.instance() return current @@ -200,6 +198,10 @@ class IOLoop(Configurable): `make_current` explicitly before starting the `IOLoop`, so that code run at startup time can find the right instance. + + .. 
versionchanged:: 4.1 + An `IOLoop` created while there is no current `IOLoop` + will automatically become current. """ IOLoop._current.instance = self @@ -224,7 +226,8 @@ class IOLoop(Configurable): return SelectIOLoop def initialize(self): - pass + if IOLoop.current(instance=False) is None: + self.make_current() def close(self, all_fds=False): """Closes the `IOLoop`, freeing any resources used. @@ -946,6 +949,9 @@ class PeriodicCallback(object): The callback is called every ``callback_time`` milliseconds. `start` must be called after the `PeriodicCallback` is created. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. """ def __init__(self, callback, callback_time, io_loop=None): self.callback = callback @@ -969,6 +975,13 @@ class PeriodicCallback(object): self.io_loop.remove_timeout(self._timeout) self._timeout = None + def is_running(self): + """Return True if this `.PeriodicCallback` has been started. + + .. versionadded:: 4.1 + """ + return self._running + def _run(self): if not self._running: return diff --git a/tornado/iostream.py b/tornado/iostream.py index 772aa4dbcb4861df017420463a04a4124f5d7329..cdb6250b9055fed5bb86fab08f3eb87b82b00cf1 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -68,6 +68,14 @@ _ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE, if hasattr(errno, "WSAECONNRESET"): _ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT) +if sys.platform == 'darwin': + # OSX appears to have a race condition that causes send(2) to return + # EPROTOTYPE if called while a socket is being torn down: + # http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/ + # Since the socket is being closed anyway, treat this as an ECONNRESET + # instead of an unexpected error. 
+ _ERRNO_CONNRESET += (errno.EPROTOTYPE,) + # More non-portable errnos: _ERRNO_INPROGRESS = (errno.EINPROGRESS,) @@ -122,6 +130,7 @@ class BaseIOStream(object): """`BaseIOStream` constructor. :arg io_loop: The `.IOLoop` to use; defaults to `.IOLoop.current`. + Deprecated since Tornado 4.1. :arg max_buffer_size: Maximum amount of incoming data to buffer; defaults to 100MB. :arg read_chunk_size: Amount of data to read at one time from the @@ -230,6 +239,12 @@ class BaseIOStream(object): gen_log.info("Unsatisfiable read, closing connection: %s" % e) self.close(exc_info=True) return future + except: + if future is not None: + # Ensure that the future doesn't log an error because its + # failure was never examined. + future.add_done_callback(lambda f: f.exception()) + raise return future def read_until(self, delimiter, callback=None, max_bytes=None): @@ -257,6 +272,10 @@ class BaseIOStream(object): gen_log.info("Unsatisfiable read, closing connection: %s" % e) self.close(exc_info=True) return future + except: + if future is not None: + future.add_done_callback(lambda f: f.exception()) + raise return future def read_bytes(self, num_bytes, callback=None, streaming_callback=None, @@ -281,7 +300,12 @@ class BaseIOStream(object): self._read_bytes = num_bytes self._read_partial = partial self._streaming_callback = stack_context.wrap(streaming_callback) - self._try_inline_read() + try: + self._try_inline_read() + except: + if future is not None: + future.add_done_callback(lambda f: f.exception()) + raise return future def read_until_close(self, callback=None, streaming_callback=None): @@ -305,7 +329,11 @@ class BaseIOStream(object): self._run_read_callback(self._read_buffer_size, False) return future self._read_until_close = True - self._try_inline_read() + try: + self._try_inline_read() + except: + future.add_done_callback(lambda f: f.exception()) + raise return future def write(self, data, callback=None): @@ -331,7 +359,7 @@ class BaseIOStream(object): if data: if 
(self.max_write_buffer_size is not None and self._write_buffer_size + len(data) > self.max_write_buffer_size): - raise StreamBufferFullError("Reached maximum read buffer size") + raise StreamBufferFullError("Reached maximum write buffer size") # Break up large contiguous strings before inserting them in the # write buffer, so we don't have to recopy the entire thing # as we slice off pieces to send to the socket. @@ -344,6 +372,7 @@ class BaseIOStream(object): future = None else: future = self._write_future = TracebackFuture() + future.add_done_callback(lambda f: f.exception()) if not self._connecting: self._handle_write() if self._write_buffer: @@ -934,9 +963,8 @@ class IOStream(BaseIOStream): return self.socket def close_fd(self): - if self.socket is not None: - self.socket.close() - self.socket = None + self.socket.close() + self.socket = None def get_fd_error(self): errno = self.socket.getsockopt(socket.SOL_SOCKET, @@ -1011,8 +1039,9 @@ class IOStream(BaseIOStream): # reported later in _handle_connect. 
if (errno_from_exception(e) not in _ERRNO_INPROGRESS and errno_from_exception(e) not in _ERRNO_WOULDBLOCK): - gen_log.warning("Connect error on fd %s: %s", - self.socket.fileno(), e) + if future is None: + gen_log.warning("Connect error on fd %s: %s", + self.socket.fileno(), e) self.close(exc_info=True) return future self._add_io_state(self.io_loop.WRITE) @@ -1059,7 +1088,9 @@ class IOStream(BaseIOStream): socket = self.socket self.io_loop.remove_handler(socket) self.socket = None - socket = ssl_wrap_socket(socket, ssl_options, server_side=server_side, + socket = ssl_wrap_socket(socket, ssl_options, + server_hostname=server_hostname, + server_side=server_side, do_handshake_on_connect=False) orig_close_callback = self._close_callback self._close_callback = None diff --git a/tornado/netutil.py b/tornado/netutil.py index f147c974d69d1a7dc8679a92d785790494c0eb87..17e9580405d664b7e9d4cc0b13ac30c405fa63f4 100644 --- a/tornado/netutil.py +++ b/tornado/netutil.py @@ -20,7 +20,7 @@ from __future__ import absolute_import, division, print_function, with_statement import errno import os -import platform +import sys import socket import stat @@ -105,7 +105,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, for res in set(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM, 0, flags)): af, socktype, proto, canonname, sockaddr = res - if (platform.system() == 'Darwin' and address == 'localhost' and + if (sys.platform == 'darwin' and address == 'localhost' and af == socket.AF_INET6 and sockaddr[3] != 0): # Mac OS X includes a link-local address fe80::1%lo0 in the # getaddrinfo results for 'localhost'. However, the firewall @@ -187,6 +187,9 @@ def add_accept_handler(sock, callback, io_loop=None): address of the other end of the connection). Note that this signature is different from the ``callback(fd, events)`` signature used for `.IOLoop` handlers. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. 
""" if io_loop is None: io_loop = IOLoop.current() @@ -301,6 +304,9 @@ class ExecutorResolver(Resolver): The executor will be shut down when the resolver is closed unless ``close_resolver=False``; use this if you want to reuse the same executor elsewhere. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. """ def initialize(self, io_loop=None, executor=None, close_executor=True): self.io_loop = io_loop or IOLoop.current() diff --git a/tornado/options.py b/tornado/options.py index 5e23e2910408c8d0483dec4d52fde503a96388a2..c855407c295ea475f864eae9497402192aa24a14 100644 --- a/tornado/options.py +++ b/tornado/options.py @@ -204,6 +204,13 @@ class OptionParser(object): (name, self._options[name].file_name)) frame = sys._getframe(0) options_file = frame.f_code.co_filename + + # Can be called directly, or through top level define() fn, in which + # case, step up above that frame to look for real caller. + if (frame.f_back.f_code.co_filename == options_file and + frame.f_back.f_code.co_name == 'define'): + frame = frame.f_back + file_name = frame.f_back.f_code.co_filename if file_name == options_file: file_name = "" diff --git a/tornado/platform/asyncio.py b/tornado/platform/asyncio.py index dd6722a49d2f7c987878c7ce8484bf7a92cbcd77..bc6851750ac119e8e22de79e9a5ffa3bc000dde5 100644 --- a/tornado/platform/asyncio.py +++ b/tornado/platform/asyncio.py @@ -12,6 +12,8 @@ unfinished callbacks on the event loop that fail when it resumes) from __future__ import absolute_import, division, print_function, with_statement import functools +import tornado.concurrent +from tornado.gen import convert_yielded from tornado.ioloop import IOLoop from tornado import stack_context @@ -138,3 +140,18 @@ class AsyncIOLoop(BaseAsyncIOLoop): def initialize(self): super(AsyncIOLoop, self).initialize(asyncio.new_event_loop(), close_loop=True) + +def to_tornado_future(asyncio_future): + """Convert an ``asyncio.Future`` to a `tornado.concurrent.Future`.""" + tf = 
tornado.concurrent.Future() + tornado.concurrent.chain_future(asyncio_future, tf) + return tf + +def to_asyncio_future(tornado_future): + """Convert a `tornado.concurrent.Future` to an ``asyncio.Future``.""" + af = asyncio.Future() + tornado.concurrent.chain_future(tornado_future, af) + return af + +if hasattr(convert_yielded, 'register'): + convert_yielded.register(asyncio.Future, to_tornado_future) diff --git a/tornado/platform/caresresolver.py b/tornado/platform/caresresolver.py index c4648c2226903a2a9ff07e3a63d5f8c3bb81b762..5559614f596b8316d684cee4d71e223235a2a14e 100644 --- a/tornado/platform/caresresolver.py +++ b/tornado/platform/caresresolver.py @@ -18,6 +18,9 @@ class CaresResolver(Resolver): so it is only recommended for use in ``AF_INET`` (i.e. IPv4). This is the default for ``tornado.simple_httpclient``, but other libraries may default to ``AF_UNSPEC``. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. """ def initialize(self, io_loop=None): self.io_loop = io_loop or IOLoop.current() diff --git a/tornado/platform/kqueue.py b/tornado/platform/kqueue.py index de8c046d3ed4be2109d8eae413d5a096687b02a6..f8f3e4a6113ee0901df45b15ae67a16b3cd2a7e2 100644 --- a/tornado/platform/kqueue.py +++ b/tornado/platform/kqueue.py @@ -54,8 +54,7 @@ class _KQueue(object): if events & IOLoop.WRITE: kevents.append(select.kevent( fd, filter=select.KQ_FILTER_WRITE, flags=flags)) - if events & IOLoop.READ or not kevents: - # Always read when there is not a write + if events & IOLoop.READ: kevents.append(select.kevent( fd, filter=select.KQ_FILTER_READ, flags=flags)) # Even though control() takes a list, it seems to return EINVAL diff --git a/tornado/platform/select.py b/tornado/platform/select.py index 9a879562651aff2f3fa60618173075627da582e7..1e1265547ce7f396f070e1125ec171dcb51c1b80 100644 --- a/tornado/platform/select.py +++ b/tornado/platform/select.py @@ -47,7 +47,7 @@ class _Select(object): # Closed connections are reported as errors by epoll and kqueue, 
# but as zero-byte reads by select, so when errors are requested # we need to listen for both read and error. - self.read_fds.add(fd) + #self.read_fds.add(fd) def modify(self, fd, events): self.unregister(fd) diff --git a/tornado/platform/twisted.py b/tornado/platform/twisted.py index 27d991cdb3733e06fe332c53f9dcc235ebd08061..09b328366b6796dccaa69f7d40f59794fd814dd5 100644 --- a/tornado/platform/twisted.py +++ b/tornado/platform/twisted.py @@ -70,8 +70,10 @@ import datetime import functools import numbers import socket +import sys import twisted.internet.abstract +from twisted.internet.defer import Deferred from twisted.internet.posixbase import PosixReactorBase from twisted.internet.interfaces import \ IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor @@ -84,6 +86,7 @@ import twisted.names.resolve from zope.interface import implementer +from tornado.concurrent import Future from tornado.escape import utf8 from tornado import gen import tornado.ioloop @@ -147,6 +150,9 @@ class TornadoReactor(PosixReactorBase): We override `mainLoop` instead of `doIteration` and must implement timed call functionality on top of `IOLoop.add_timeout` rather than using the implementation in `PosixReactorBase`. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. """ def __init__(self, io_loop=None): if not io_loop: @@ -356,7 +362,11 @@ class _TestReactor(TornadoReactor): def install(io_loop=None): - """Install this package as the default Twisted reactor.""" + """Install this package as the default Twisted reactor. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. + """ if not io_loop: io_loop = tornado.ioloop.IOLoop.current() reactor = TornadoReactor(io_loop) @@ -512,6 +522,9 @@ class TwistedResolver(Resolver): ``socket.AF_UNSPEC``. Requires Twisted 12.1 or newer. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. 
""" def initialize(self, io_loop=None): self.io_loop = io_loop or IOLoop.current() @@ -554,3 +567,17 @@ class TwistedResolver(Resolver): (resolved_family, (resolved, port)), ] raise gen.Return(result) + +if hasattr(gen.convert_yielded, 'register'): + @gen.convert_yielded.register(Deferred) + def _(d): + f = Future() + def errback(failure): + try: + failure.raiseException() + # Should never happen, but just in case + raise Exception("errback called without error") + except: + f.set_exc_info(sys.exc_info()) + d.addCallbacks(f.set_result, errback) + return f diff --git a/tornado/process.py b/tornado/process.py index cea3dbd01d54b32d00f28595abc86a97052cbbd9..3790ca0a55f99c591f02aabe18a9acdd9a2f2afc 100644 --- a/tornado/process.py +++ b/tornado/process.py @@ -191,6 +191,9 @@ class Subprocess(object): ``tornado.process.Subprocess.STREAM``, which will make the corresponding attribute of the resulting Subprocess a `.PipeIOStream`. * A new keyword argument ``io_loop`` may be used to pass in an IOLoop. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. """ STREAM = object() @@ -263,6 +266,9 @@ class Subprocess(object): Note that the `.IOLoop` used for signal handling need not be the same one used by individual Subprocess objects (as long as the ``IOLoops`` are each running in separate threads). + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. 
""" if cls._initialized: return diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index cf30f072e6016a68c89b5750e142a69e278d4ebe..31d076e2d114d0840dcf586410aee11eebcaf43f 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -34,7 +34,7 @@ except ImportError: ssl = None try: - import lib.certifi + import certifi except ImportError: certifi = None @@ -193,12 +193,8 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): netloc = self.parsed.netloc if "@" in netloc: userpass, _, netloc = netloc.rpartition("@") - match = re.match(r'^(.+):(\d+)$', netloc) - if match: - host = match.group(1) - port = int(match.group(2)) - else: - host = netloc + host, port = httputil.split_host_and_port(netloc) + if port is None: port = 443 if self.parsed.scheme == "https" else 80 if re.match(r'^\[.*\]$', host): # raw ipv6 addresses in urls are enclosed in brackets @@ -349,7 +345,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): decompress=self.request.decompress_response), self._sockaddr) start_line = httputil.RequestStartLine(self.request.method, - req_path, 'HTTP/1.1') + req_path, '') self.connection.write_headers(start_line, self.request.headers) if self.request.expect_100_continue: self._read_response() diff --git a/tornado/tcpclient.py b/tornado/tcpclient.py index 0abbea200584bf87ef3c9df0951614be8ebb5d19..f594d91b8857e69ddb772ec7e65ae5f18470ae36 100644 --- a/tornado/tcpclient.py +++ b/tornado/tcpclient.py @@ -111,6 +111,7 @@ class _Connector(object): if self.timeout is not None: # If the first attempt failed, don't wait for the # timeout to try an address from the secondary queue. + self.io_loop.remove_timeout(self.timeout) self.on_timeout() return self.clear_timeout() @@ -135,6 +136,9 @@ class _Connector(object): class TCPClient(object): """A non-blocking TCP connection factory. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. 
""" def __init__(self, resolver=None, io_loop=None): self.io_loop = io_loop or IOLoop.current() diff --git a/tornado/tcpserver.py b/tornado/tcpserver.py index 427acec5758f2728279a1e64c30dca23f91262c6..a02b36ffffda457172f7c568decea50c1829ea1a 100644 --- a/tornado/tcpserver.py +++ b/tornado/tcpserver.py @@ -95,7 +95,7 @@ class TCPServer(object): self._pending_sockets = [] self._started = False self.max_buffer_size = max_buffer_size - self.read_chunk_size = None + self.read_chunk_size = read_chunk_size # Verify the SSL options. Otherwise we don't get errors until clients # connect. This doesn't verify that the keys are legitimate, but diff --git a/tornado/test/asyncio_test.py b/tornado/test/asyncio_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cb990748f6f45173b982e9d7c0dce4ccbd35fb95 --- /dev/null +++ b/tornado/test/asyncio_test.py @@ -0,0 +1,68 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +from __future__ import absolute_import, division, print_function, with_statement + +import sys +import textwrap + +from tornado import gen +from tornado.testing import AsyncTestCase, gen_test +from tornado.test.util import unittest + +try: + from tornado.platform.asyncio import asyncio, AsyncIOLoop +except ImportError: + asyncio = None + +skipIfNoSingleDispatch = unittest.skipIf( + gen.singledispatch is None, "singledispatch module not present") + +@unittest.skipIf(asyncio is None, "asyncio module not present") +class AsyncIOLoopTest(AsyncTestCase): + def get_new_ioloop(self): + io_loop = AsyncIOLoop() + asyncio.set_event_loop(io_loop.asyncio_loop) + return io_loop + + def test_asyncio_callback(self): + # Basic test that the asyncio loop is set up correctly. + asyncio.get_event_loop().call_soon(self.stop) + self.wait() + + @skipIfNoSingleDispatch + @gen_test + def test_asyncio_future(self): + # Test that we can yield an asyncio future from a tornado coroutine. + # Without 'yield from', we must wrap coroutines in asyncio.async. + x = yield asyncio.async( + asyncio.get_event_loop().run_in_executor(None, lambda: 42)) + self.assertEqual(x, 42) + + @unittest.skipIf(sys.version_info < (3, 3), + 'PEP 380 not available') + @skipIfNoSingleDispatch + @gen_test + def test_asyncio_yield_from(self): + # Test that we can use asyncio coroutines with 'yield from' + # instead of asyncio.async(). This requires python 3.3 syntax. 
+ global_namespace = dict(globals(), **locals()) + local_namespace = {} + exec(textwrap.dedent(""" + @gen.coroutine + def f(): + event_loop = asyncio.get_event_loop() + x = yield from event_loop.run_in_executor(None, lambda: 42) + return x + """), global_namespace, local_namespace) + result = yield local_namespace['f']() + self.assertEqual(result, 42) diff --git a/tornado/test/gen_test.py b/tornado/test/gen_test.py index a15cdf73a152f970986937d233ecc06f42ecd0d1..5d646d15007a0b100d5f9fa0c00623ba48690372 100644 --- a/tornado/test/gen_test.py +++ b/tornado/test/gen_test.py @@ -838,6 +838,11 @@ class GenCoroutineTest(AsyncTestCase): yield [f('a', gen.moment), f('b', immediate)] self.assertEqual(''.join(calls), 'abbbbbaaaa') + @gen_test + def test_sleep(self): + yield gen.sleep(0.01) + self.finished = True + class GenSequenceHandler(RequestHandler): @asynchronous @@ -1031,7 +1036,7 @@ class WithTimeoutTest(AsyncTestCase): self.io_loop.add_timeout(datetime.timedelta(seconds=0.1), lambda: future.set_result('asdf')) result = yield gen.with_timeout(datetime.timedelta(seconds=3600), - future) + future, io_loop=self.io_loop) self.assertEqual(result, 'asdf') @gen_test @@ -1039,16 +1044,17 @@ class WithTimeoutTest(AsyncTestCase): future = Future() self.io_loop.add_timeout( datetime.timedelta(seconds=0.1), - lambda: future.set_exception(ZeroDivisionError)) + lambda: future.set_exception(ZeroDivisionError())) with self.assertRaises(ZeroDivisionError): - yield gen.with_timeout(datetime.timedelta(seconds=3600), future) + yield gen.with_timeout(datetime.timedelta(seconds=3600), + future, io_loop=self.io_loop) @gen_test def test_already_resolved(self): future = Future() future.set_result('asdf') result = yield gen.with_timeout(datetime.timedelta(seconds=3600), - future) + future, io_loop=self.io_loop) self.assertEqual(result, 'asdf') @unittest.skipIf(futures is None, 'futures module not present') @@ -1066,6 +1072,107 @@ class WithTimeoutTest(AsyncTestCase): yield 
gen.with_timeout(datetime.timedelta(seconds=3600), executor.submit(lambda: None)) +class WaitIteratorTest(AsyncTestCase): + @gen_test + def test_empty_iterator(self): + g = gen.WaitIterator() + self.assertTrue(g.done(), 'empty generator iterated') + + with self.assertRaises(ValueError): + g = gen.WaitIterator(False, bar=False) + + self.assertEqual(g.current_index, None, "bad nil current index") + self.assertEqual(g.current_future, None, "bad nil current future") + + @gen_test + def test_already_done(self): + f1 = Future() + f2 = Future() + f3 = Future() + f1.set_result(24) + f2.set_result(42) + f3.set_result(84) + + g = gen.WaitIterator(f1, f2, f3) + i = 0 + while not g.done(): + r = yield g.next() + # Order is not guaranteed, but the current implementation + # preserves ordering of already-done Futures. + if i == 0: + self.assertEqual(g.current_index, 0) + self.assertIs(g.current_future, f1) + self.assertEqual(r, 24) + elif i == 1: + self.assertEqual(g.current_index, 1) + self.assertIs(g.current_future, f2) + self.assertEqual(r, 42) + elif i == 2: + self.assertEqual(g.current_index, 2) + self.assertIs(g.current_future, f3) + self.assertEqual(r, 84) + i += 1 + + self.assertEqual(g.current_index, None, "bad nil current index") + self.assertEqual(g.current_future, None, "bad nil current future") + + dg = gen.WaitIterator(f1=f1, f2=f2) + + while not dg.done(): + dr = yield dg.next() + if dg.current_index == "f1": + self.assertTrue(dg.current_future==f1 and dr==24, + "WaitIterator dict status incorrect") + elif dg.current_index == "f2": + self.assertTrue(dg.current_future==f2 and dr==42, + "WaitIterator dict status incorrect") + else: + self.fail("got bad WaitIterator index {}".format( + dg.current_index)) + + i += 1 + + self.assertEqual(dg.current_index, None, "bad nil current index") + self.assertEqual(dg.current_future, None, "bad nil current future") + + def finish_coroutines(self, iteration, futures): + if iteration == 3: + futures[2].set_result(24) + elif 
iteration == 5: + futures[0].set_exception(ZeroDivisionError()) + elif iteration == 8: + futures[1].set_result(42) + futures[3].set_result(84) + + if iteration < 8: + self.io_loop.add_callback(self.finish_coroutines, iteration+1, futures) + + @gen_test + def test_iterator(self): + futures = [Future(), Future(), Future(), Future()] + + self.finish_coroutines(0, futures) + + g = gen.WaitIterator(*futures) + + i = 0 + while not g.done(): + try: + r = yield g.next() + except ZeroDivisionError: + self.assertIs(g.current_future, futures[0], + 'exception future invalid') + else: + if i == 0: + self.assertEqual(r, 24, 'iterator value incorrect') + self.assertEqual(g.current_index, 2, 'wrong index') + elif i == 2: + self.assertEqual(r, 42, 'iterator value incorrect') + self.assertEqual(g.current_index, 1, 'wrong index') + elif i == 3: + self.assertEqual(r, 84, 'iterator value incorrect') + self.assertEqual(g.current_index, 3, 'wrong index') + i += 1 if __name__ == '__main__': unittest.main() diff --git a/tornado/test/httpclient_test.py b/tornado/test/httpclient_test.py index bfb50d878d05bbbc6926959c2f48955b529b9cf5..875864ac69ff66ab851a8b95819ba6f62e1fc607 100644 --- a/tornado/test/httpclient_test.py +++ b/tornado/test/httpclient_test.py @@ -404,6 +404,11 @@ Transfer-Encoding: chunked self.assertEqual(context.exception.code, 404) self.assertEqual(context.exception.response.code, 404) + @gen_test + def test_future_http_error_no_raise(self): + response = yield self.http_client.fetch(self.get_url('/notfound'), raise_error=False) + self.assertEqual(response.code, 404) + @gen_test def test_reuse_request_from_response(self): # The response.request attribute should be an HTTPRequest, not @@ -543,7 +548,7 @@ class SyncHTTPClientTest(unittest.TestCase): self.server_ioloop.close(all_fds=True) def get_url(self, path): - return 'http://localhost:%d%s' % (self.port, path) + return 'http://127.0.0.1:%d%s' % (self.port, path) def test_sync_client(self): response = 
self.http_client.fetch(self.get_url('/')) diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index 156a027bbd235a720b4186bf6975f095b2b260c6..64ef96d459406d7224bbe1e968da62824ffc6291 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -195,14 +195,14 @@ class HTTPConnectionTest(AsyncHTTPTestCase): def get_app(self): return Application(self.get_handlers()) - def raw_fetch(self, headers, body): + def raw_fetch(self, headers, body, newline=b"\r\n"): with closing(IOStream(socket.socket())) as stream: stream.connect(('127.0.0.1', self.get_http_port()), self.stop) self.wait() stream.write( - b"\r\n".join(headers + - [utf8("Content-Length: %d\r\n" % len(body))]) + - b"\r\n" + body) + newline.join(headers + + [utf8("Content-Length: %d" % len(body))]) + + newline + newline + body) read_stream_body(stream, self.stop) headers, body = self.wait() return body @@ -232,12 +232,19 @@ class HTTPConnectionTest(AsyncHTTPTestCase): self.assertEqual(u("\u00f3"), data["filename"]) self.assertEqual(u("\u00fa"), data["filebody"]) + def test_newlines(self): + # We support both CRLF and bare LF as line separators. + for newline in (b"\r\n", b"\n"): + response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"", + newline=newline) + self.assertEqual(response, b'Hello world') + def test_100_continue(self): # Run through a 100-continue interaction by hand: # When given Expect: 100-continue, we get a 100 response after the # headers, and then the real response after the body. 
stream = IOStream(socket.socket(), io_loop=self.io_loop) - stream.connect(("localhost", self.get_http_port()), callback=self.stop) + stream.connect(("127.0.0.1", self.get_http_port()), callback=self.stop) self.wait() stream.write(b"\r\n".join([b"POST /hello HTTP/1.1", b"Content-Length: 1024", @@ -374,7 +381,7 @@ class HTTPServerRawTest(AsyncHTTPTestCase): def setUp(self): super(HTTPServerRawTest, self).setUp() self.stream = IOStream(socket.socket()) - self.stream.connect(('localhost', self.get_http_port()), self.stop) + self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop) self.wait() def tearDown(self): @@ -555,7 +562,7 @@ class UnixSocketTest(AsyncTestCase): self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n") self.stream.read_until(b"\r\n", self.stop) response = self.wait() - self.assertEqual(response, b"HTTP/1.0 200 OK\r\n") + self.assertEqual(response, b"HTTP/1.1 200 OK\r\n") self.stream.read_until(b"\r\n\r\n", self.stop) headers = HTTPHeaders.parse(self.wait().decode('latin1')) self.stream.read_bytes(int(headers["Content-Length"]), self.stop) @@ -623,13 +630,13 @@ class KeepAliveTest(AsyncHTTPTestCase): # The next few methods are a crude manual http client def connect(self): self.stream = IOStream(socket.socket(), io_loop=self.io_loop) - self.stream.connect(('localhost', self.get_http_port()), self.stop) + self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop) self.wait() def read_headers(self): self.stream.read_until(b'\r\n', self.stop) first_line = self.wait() - self.assertTrue(first_line.startswith(self.http_version + b' 200'), first_line) + self.assertTrue(first_line.startswith(b'HTTP/1.1 200'), first_line) self.stream.read_until(b'\r\n\r\n', self.stop) header_bytes = self.wait() headers = HTTPHeaders.parse(header_bytes.decode('latin1')) @@ -808,8 +815,8 @@ class StreamingChunkSizeTest(AsyncHTTPTestCase): def get_app(self): class App(HTTPServerConnectionDelegate): - def start_request(self, connection): - return 
StreamingChunkSizeTest.MessageDelegate(connection) + def start_request(self, server_conn, request_conn): + return StreamingChunkSizeTest.MessageDelegate(request_conn) return App() def fetch_chunk_sizes(self, **kwargs): @@ -900,7 +907,7 @@ class IdleTimeoutTest(AsyncHTTPTestCase): def connect(self): stream = IOStream(socket.socket()) - stream.connect(('localhost', self.get_http_port()), self.stop) + stream.connect(('127.0.0.1', self.get_http_port()), self.stop) self.wait() self.streams.append(stream) return stream diff --git a/tornado/test/httputil_test.py b/tornado/test/httputil_test.py index 5ca5cf9f335cc50f63b3ae48356fdcb2d4c954e3..3995abe8b94ab9166e0ed43205b12cf4ef578e3d 100644 --- a/tornado/test/httputil_test.py +++ b/tornado/test/httputil_test.py @@ -3,10 +3,11 @@ from __future__ import absolute_import, division, print_function, with_statement from tornado.httputil import url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp, HTTPServerRequest, parse_request_start_line -from tornado.escape import utf8 +from tornado.escape import utf8, native_str from tornado.log import gen_log from tornado.testing import ExpectLog from tornado.test.util import unittest +from tornado.util import u import datetime import logging @@ -228,6 +229,57 @@ Foo: even ("Foo", "bar baz"), ("Foo", "even more lines")]) + def test_unicode_newlines(self): + # Ensure that only \r\n is recognized as a header separator, and not + # the other newline-like unicode characters. + # Characters that are likely to be problematic can be found in + # http://unicode.org/standard/reports/tr13/tr13-5.html + # and cpython's unicodeobject.c (which defines the implementation + # of unicode_type.splitlines(), and uses a different list than TR13). 
+ newlines = [ + u('\u001b'), # VERTICAL TAB + u('\u001c'), # FILE SEPARATOR + u('\u001d'), # GROUP SEPARATOR + u('\u001e'), # RECORD SEPARATOR + u('\u0085'), # NEXT LINE + u('\u2028'), # LINE SEPARATOR + u('\u2029'), # PARAGRAPH SEPARATOR + ] + for newline in newlines: + # Try the utf8 and latin1 representations of each newline + for encoding in ['utf8', 'latin1']: + try: + try: + encoded = newline.encode(encoding) + except UnicodeEncodeError: + # Some chars cannot be represented in latin1 + continue + data = b'Cookie: foo=' + encoded + b'bar' + # parse() wants a native_str, so decode through latin1 + # in the same way the real parser does. + headers = HTTPHeaders.parse( + native_str(data.decode('latin1'))) + expected = [('Cookie', 'foo=' + + native_str(encoded.decode('latin1')) + 'bar')] + self.assertEqual( + expected, list(headers.get_all())) + except Exception: + gen_log.warning("failed while trying %r in %s", + newline, encoding) + raise + + def test_optional_cr(self): + # Both CRLF and LF should be accepted as separators. CR should not be + # part of the data when followed by LF, but it is a normal char + # otherwise (or should bare CR be an error?) + headers = HTTPHeaders.parse( + 'CRLF: crlf\r\nLF: lf\nCR: cr\rMore: more\r\n') + self.assertEqual(sorted(headers.get_all()), + [('Cr', 'cr\rMore: more'), + ('Crlf', 'crlf'), + ('Lf', 'lf'), + ]) + class FormatTimestampTest(unittest.TestCase): # Make sure that all the input types are supported. @@ -264,6 +316,10 @@ class HTTPServerRequestTest(unittest.TestCase): # more required parameters slip in. 
HTTPServerRequest(uri='/') + def test_body_is_a_byte_string(self): + requets = HTTPServerRequest(uri='/') + self.assertIsInstance(requets.body, bytes) + class ParseRequestStartLineTest(unittest.TestCase): METHOD = "GET" diff --git a/tornado/test/iostream_test.py b/tornado/test/iostream_test.py index f51caeaf7faec1c35c522e078c2c22a55d048add..ca35de69bbae0b48467642c4b157e17f3878cab1 100644 --- a/tornado/test/iostream_test.py +++ b/tornado/test/iostream_test.py @@ -8,9 +8,9 @@ from tornado.log import gen_log, app_log from tornado.netutil import ssl_wrap_socket from tornado.stack_context import NullContext from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test -from tornado.test.util import unittest, skipIfNonUnix +from tornado.test.util import unittest, skipIfNonUnix, refusing_port from tornado.web import RequestHandler, Application -import lib.certifi +import certifi import errno import logging import os @@ -51,18 +51,18 @@ class TestIOStreamWebMixin(object): def test_read_until_close(self): stream = self._make_client_iostream() - stream.connect(('localhost', self.get_http_port()), callback=self.stop) + stream.connect(('127.0.0.1', self.get_http_port()), callback=self.stop) self.wait() stream.write(b"GET / HTTP/1.0\r\n\r\n") stream.read_until_close(self.stop) data = self.wait() - self.assertTrue(data.startswith(b"HTTP/1.0 200")) + self.assertTrue(data.startswith(b"HTTP/1.1 200")) self.assertTrue(data.endswith(b"Hello")) def test_read_zero_bytes(self): self.stream = self._make_client_iostream() - self.stream.connect(("localhost", self.get_http_port()), + self.stream.connect(("127.0.0.1", self.get_http_port()), callback=self.stop) self.wait() self.stream.write(b"GET / HTTP/1.0\r\n\r\n") @@ -70,7 +70,7 @@ class TestIOStreamWebMixin(object): # normal read self.stream.read_bytes(9, self.stop) data = self.wait() - self.assertEqual(data, b"HTTP/1.0 ") + self.assertEqual(data, b"HTTP/1.1 ") # zero bytes 
self.stream.read_bytes(0, self.stop) @@ -91,7 +91,7 @@ class TestIOStreamWebMixin(object): def connected_callback(): connected[0] = True self.stop() - stream.connect(("localhost", self.get_http_port()), + stream.connect(("127.0.0.1", self.get_http_port()), callback=connected_callback) # unlike the previous tests, try to write before the connection # is complete. @@ -121,11 +121,11 @@ class TestIOStreamWebMixin(object): """Basic test of IOStream's ability to return Futures.""" stream = self._make_client_iostream() connect_result = yield stream.connect( - ("localhost", self.get_http_port())) + ("127.0.0.1", self.get_http_port())) self.assertIs(connect_result, stream) yield stream.write(b"GET / HTTP/1.0\r\n\r\n") first_line = yield stream.read_until(b"\r\n") - self.assertEqual(first_line, b"HTTP/1.0 200 OK\r\n") + self.assertEqual(first_line, b"HTTP/1.1 200 OK\r\n") # callback=None is equivalent to no callback. header_data = yield stream.read_until(b"\r\n\r\n", callback=None) headers = HTTPHeaders.parse(header_data.decode('latin1')) @@ -137,7 +137,7 @@ class TestIOStreamWebMixin(object): @gen_test def test_future_close_while_reading(self): stream = self._make_client_iostream() - yield stream.connect(("localhost", self.get_http_port())) + yield stream.connect(("127.0.0.1", self.get_http_port())) yield stream.write(b"GET / HTTP/1.0\r\n\r\n") with self.assertRaises(StreamClosedError): yield stream.read_bytes(1024 * 1024) @@ -147,7 +147,7 @@ class TestIOStreamWebMixin(object): def test_future_read_until_close(self): # Ensure that the data comes through before the StreamClosedError. 
stream = self._make_client_iostream() - yield stream.connect(("localhost", self.get_http_port())) + yield stream.connect(("127.0.0.1", self.get_http_port())) yield stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n") yield stream.read_until(b"\r\n\r\n") body = yield stream.read_until_close() @@ -217,17 +217,18 @@ class TestIOStreamMixin(object): # When a connection is refused, the connect callback should not # be run. (The kqueue IOLoop used to behave differently from the # epoll IOLoop in this respect) - server_socket, port = bind_unused_port() - server_socket.close() + cleanup_func, port = refusing_port() + self.addCleanup(cleanup_func) stream = IOStream(socket.socket(), self.io_loop) self.connect_called = False def connect_callback(): self.connect_called = True + self.stop() stream.set_close_callback(self.stop) # log messages vary by platform and ioloop implementation with ExpectLog(gen_log, ".*", required=False): - stream.connect(("localhost", port), connect_callback) + stream.connect(("127.0.0.1", port), connect_callback) self.wait() self.assertFalse(self.connect_called) self.assertTrue(isinstance(stream.error, socket.error), stream.error) @@ -248,7 +249,8 @@ class TestIOStreamMixin(object): # opendns and some ISPs return bogus addresses for nonexistent # domains instead of the proper error codes). with ExpectLog(gen_log, "Connect error"): - stream.connect(('an invalid domain', 54321)) + stream.connect(('an invalid domain', 54321), callback=self.stop) + self.wait() self.assertTrue(isinstance(stream.error, socket.gaierror), stream.error) def test_read_callback_error(self): @@ -724,6 +726,26 @@ class TestIOStreamMixin(object): server.close() client.close() + def test_flow_control(self): + MB = 1024 * 1024 + server, client = self.make_iostream_pair(max_buffer_size=5 * MB) + try: + # Client writes more than the server will accept. + client.write(b"a" * 10 * MB) + # The server pauses while reading. 
+ server.read_bytes(MB, self.stop) + self.wait() + self.io_loop.call_later(0.1, self.stop) + self.wait() + # The client's writes have been blocked; the server can + # continue to read gradually. + for i in range(9): + server.read_bytes(MB, self.stop) + self.wait() + finally: + server.close() + client.close() + class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase): def _make_client_iostream(self): @@ -820,10 +842,10 @@ class TestIOStreamStartTLS(AsyncTestCase): recv_line = yield self.client_stream.read_until(b"\r\n") self.assertEqual(line, recv_line) - def client_start_tls(self, ssl_options=None): + def client_start_tls(self, ssl_options=None, server_hostname=None): client_stream = self.client_stream self.client_stream = None - return client_stream.start_tls(False, ssl_options) + return client_stream.start_tls(False, ssl_options, server_hostname) def server_start_tls(self, ssl_options=None): server_stream = self.server_stream @@ -853,12 +875,32 @@ class TestIOStreamStartTLS(AsyncTestCase): @gen_test def test_handshake_fail(self): - self.server_start_tls(_server_ssl_options()) + server_future = self.server_start_tls(_server_ssl_options()) client_future = self.client_start_tls( dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=certifi.where())) with ExpectLog(gen_log, "SSL Error"): with self.assertRaises(ssl.SSLError): yield client_future + with self.assertRaises((ssl.SSLError, socket.error)): + yield server_future + + + @unittest.skipIf(not hasattr(ssl, 'create_default_context'), + 'ssl.create_default_context not present') + @gen_test + def test_check_hostname(self): + # Test that server_hostname parameter to start_tls is being used. + # The check_hostname functionality is only available in python 2.7 and + # up and in python 3.4 and up. 
+ server_future = self.server_start_tls(_server_ssl_options()) + client_future = self.client_start_tls( + ssl.create_default_context(), + server_hostname=b'127.0.0.1') + with ExpectLog(gen_log, "SSL Error"): + with self.assertRaises(ssl.SSLError): + yield client_future + with self.assertRaises((ssl.SSLError, socket.error)): + yield server_future @skipIfNonUnix diff --git a/tornado/test/runtests.py b/tornado/test/runtests.py index a80b80b9268017af0d4c6f2963c8e55500528fa6..acbb5695e2a11bdd90985c2acae82828f5104a1e 100644 --- a/tornado/test/runtests.py +++ b/tornado/test/runtests.py @@ -22,6 +22,7 @@ TEST_MODULES = [ 'tornado.httputil.doctests', 'tornado.iostream.doctests', 'tornado.util.doctests', + 'tornado.test.asyncio_test', 'tornado.test.auth_test', 'tornado.test.concurrent_test', 'tornado.test.curl_httpclient_test', @@ -67,6 +68,21 @@ class TornadoTextTestRunner(unittest.TextTestRunner): return result +class LogCounter(logging.Filter): + """Counts the number of WARNING or higher log records.""" + def __init__(self, *args, **kwargs): + # Can't use super() because logging.Filter is an old-style class in py26 + logging.Filter.__init__(self, *args, **kwargs) + self.warning_count = self.error_count = 0 + + def filter(self, record): + if record.levelno >= logging.ERROR: + self.error_count += 1 + elif record.levelno >= logging.WARNING: + self.warning_count += 1 + return True + + def main(): # The -W command-line option does not work in a virtualenv with # python 3 (as of virtualenv 1.7), so configure warnings @@ -92,6 +108,13 @@ def main(): # 2.7 and 3.2 warnings.filterwarnings("ignore", category=DeprecationWarning, message="Please use assert.* instead") + # unittest2 0.6 on py26 reports these as PendingDeprecationWarnings + # instead of DeprecationWarnings. + warnings.filterwarnings("ignore", category=PendingDeprecationWarning, + message="Please use assert.* instead") + # Twisted 15.0.0 triggers some warnings on py3 with -bb. 
+ warnings.filterwarnings("ignore", category=BytesWarning, + module=r"twisted\..*") logging.getLogger("tornado.access").setLevel(logging.CRITICAL) @@ -121,6 +144,10 @@ def main(): IOLoop.configure(options.ioloop, **kwargs) add_parse_callback(configure_ioloop) + log_counter = LogCounter() + add_parse_callback( + lambda: logging.getLogger().handlers[0].addFilter(log_counter)) + import tornado.testing kwargs = {} if sys.version_info >= (3, 2): @@ -131,7 +158,16 @@ def main(): # detail. http://bugs.python.org/issue15626 kwargs['warnings'] = False kwargs['testRunner'] = TornadoTextTestRunner - tornado.testing.main(**kwargs) + try: + tornado.testing.main(**kwargs) + finally: + # The tests should run clean; consider it a failure if they logged + # any warnings or errors. We'd like to ban info logs too, but + # we can't count them cleanly due to interactions with LogTrapTestCase. + if log_counter.warning_count > 0 or log_counter.error_count > 0: + logging.error("logged %d warnings and %d errors", + log_counter.warning_count, log_counter.error_count) + sys.exit(1) if __name__ == '__main__': main() diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py index e3fab57a2fb1107eccfc118636f3362a974db1a2..bb870db3b0308d109687ff78241b0c3edb7b67bb 100644 --- a/tornado/test/simple_httpclient_test.py +++ b/tornado/test/simple_httpclient_test.py @@ -19,8 +19,8 @@ from tornado.netutil import Resolver, bind_sockets from tornado.simple_httpclient import SimpleAsyncHTTPClient, _default_ca_certs from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler from tornado.test import httpclient_test -from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog -from tornado.test.util import skipOnTravis, skipIfNoIPv6 +from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog +from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port 
from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body @@ -235,9 +235,16 @@ class SimpleHTTPClientTestMixin(object): @skipOnTravis def test_request_timeout(self): - response = self.fetch('/trigger?wake=false', request_timeout=0.1) + timeout = 0.1 + timeout_min, timeout_max = 0.099, 0.15 + if os.name == 'nt': + timeout = 0.5 + timeout_min, timeout_max = 0.4, 0.6 + + response = self.fetch('/trigger?wake=false', request_timeout=timeout) self.assertEqual(response.code, 599) - self.assertTrue(0.099 < response.request_time < 0.15, response.request_time) + self.assertTrue(timeout_min < response.request_time < timeout_max, + response.request_time) self.assertEqual(str(response.error), "HTTP 599: Timeout") # trigger the hanging request to let it clean up after itself self.triggers.popleft()() @@ -315,10 +322,10 @@ class SimpleHTTPClientTestMixin(object): self.assertTrue(host_re.match(response.body), response.body) def test_connection_refused(self): - server_socket, port = bind_unused_port() - server_socket.close() + cleanup_func, port = refusing_port() + self.addCleanup(cleanup_func) with ExpectLog(gen_log, ".*", required=False): - self.http_client.fetch("http://localhost:%d/" % port, self.stop) + self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop) response = self.wait() self.assertEqual(599, response.code) diff --git a/tornado/test/tcpclient_test.py b/tornado/test/tcpclient_test.py index 5df4a7abb30e398eebb247da67bf18ad2a1dcb4b..1a4201e6b7712afdf5a00574868114d1dc7380bf 100644 --- a/tornado/test/tcpclient_test.py +++ b/tornado/test/tcpclient_test.py @@ -24,8 +24,8 @@ from tornado.concurrent import Future from tornado.netutil import bind_sockets, Resolver from tornado.tcpclient import TCPClient, _Connector from tornado.tcpserver import TCPServer -from tornado.testing import AsyncTestCase, bind_unused_port, gen_test -from tornado.test.util import skipIfNoIPv6, unittest +from tornado.testing import AsyncTestCase, gen_test +from 
tornado.test.util import skipIfNoIPv6, unittest, refusing_port # Fake address families for testing. Used in place of AF_INET # and AF_INET6 because some installations do not have AF_INET6. @@ -120,8 +120,8 @@ class TCPClientTest(AsyncTestCase): @gen_test def test_refused_ipv4(self): - sock, port = bind_unused_port() - sock.close() + cleanup_func, port = refusing_port() + self.addCleanup(cleanup_func) with self.assertRaises(IOError): yield self.client.connect('127.0.0.1', port) diff --git a/tornado/test/twisted_test.py b/tornado/test/twisted_test.py index 2922a61e000507536aa5546cff0748bf7aab4c38..b31ae94cb90e5e4f01be1a005291bac36adb5f53 100644 --- a/tornado/test/twisted_test.py +++ b/tornado/test/twisted_test.py @@ -19,15 +19,18 @@ Unittest for the twisted-style reactor. from __future__ import absolute_import, division, print_function, with_statement +import logging import os import shutil import signal +import sys import tempfile import threading +import warnings try: import fcntl - from twisted.internet.defer import Deferred + from twisted.internet.defer import Deferred, inlineCallbacks, returnValue from twisted.internet.interfaces import IReadDescriptor, IWriteDescriptor from twisted.internet.protocol import Protocol from twisted.python import log @@ -40,10 +43,12 @@ except ImportError: # The core of Twisted 12.3.0 is available on python 3, but twisted.web is not # so test for it separately. try: - from twisted.web.client import Agent + from twisted.web.client import Agent, readBody from twisted.web.resource import Resource from twisted.web.server import Site - have_twisted_web = True + # As of Twisted 15.0.0, twisted.web is present but fails our + # tests due to internal str/bytes errors. 
+ have_twisted_web = sys.version_info < (3,) except ImportError: have_twisted_web = False @@ -52,6 +57,8 @@ try: except ImportError: import _thread as thread # py3 +from tornado.escape import utf8 +from tornado import gen from tornado.httpclient import AsyncHTTPClient from tornado.httpserver import HTTPServer from tornado.ioloop import IOLoop @@ -65,6 +72,8 @@ from tornado.web import RequestHandler, Application skipIfNoTwisted = unittest.skipUnless(have_twisted, "twisted module not present") +skipIfNoSingleDispatch = unittest.skipIf( + gen.singledispatch is None, "singledispatch module not present") def save_signal_handlers(): saved = {} @@ -407,7 +416,7 @@ class CompatibilityTests(unittest.TestCase): # http://twistedmatrix.com/documents/current/web/howto/client.html chunks = [] client = Agent(self.reactor) - d = client.request('GET', url) + d = client.request(b'GET', utf8(url)) class Accumulator(Protocol): def __init__(self, finished): @@ -425,38 +434,98 @@ class CompatibilityTests(unittest.TestCase): return finished d.addCallback(callback) - def shutdown(ignored): - self.stop_loop() + def shutdown(failure): + if hasattr(self, 'stop_loop'): + self.stop_loop() + elif failure is not None: + # loop hasn't been initialized yet; try our best to + # get an error message out. (the runner() interaction + # should probably be refactored). + try: + failure.raiseException() + except: + logging.error('exception before starting loop', exc_info=True) d.addBoth(shutdown) runner() self.assertTrue(chunks) return ''.join(chunks) + def twisted_coroutine_fetch(self, url, runner): + body = [None] + @gen.coroutine + def f(): + # This is simpler than the non-coroutine version, but it cheats + # by reading the body in one blob instead of streaming it with + # a Protocol. 
+ client = Agent(self.reactor) + response = yield client.request(b'GET', utf8(url)) + with warnings.catch_warnings(): + # readBody has a buggy DeprecationWarning in Twisted 15.0: + # https://twistedmatrix.com/trac/changeset/43379 + warnings.simplefilter('ignore', category=DeprecationWarning) + body[0] = yield readBody(response) + self.stop_loop() + self.io_loop.add_callback(f) + runner() + return body[0] + def testTwistedServerTornadoClientIOLoop(self): self.start_twisted_server() response = self.tornado_fetch( - 'http://localhost:%d' % self.twisted_port, self.run_ioloop) + 'http://127.0.0.1:%d' % self.twisted_port, self.run_ioloop) self.assertEqual(response.body, 'Hello from twisted!') def testTwistedServerTornadoClientReactor(self): self.start_twisted_server() response = self.tornado_fetch( - 'http://localhost:%d' % self.twisted_port, self.run_reactor) + 'http://127.0.0.1:%d' % self.twisted_port, self.run_reactor) self.assertEqual(response.body, 'Hello from twisted!') def testTornadoServerTwistedClientIOLoop(self): self.start_tornado_server() response = self.twisted_fetch( - 'http://localhost:%d' % self.tornado_port, self.run_ioloop) + 'http://127.0.0.1:%d' % self.tornado_port, self.run_ioloop) self.assertEqual(response, 'Hello from tornado!') def testTornadoServerTwistedClientReactor(self): self.start_tornado_server() response = self.twisted_fetch( - 'http://localhost:%d' % self.tornado_port, self.run_reactor) + 'http://127.0.0.1:%d' % self.tornado_port, self.run_reactor) + self.assertEqual(response, 'Hello from tornado!') + + @skipIfNoSingleDispatch + def testTornadoServerTwistedCoroutineClientIOLoop(self): + self.start_tornado_server() + response = self.twisted_coroutine_fetch( + 'http://127.0.0.1:%d' % self.tornado_port, self.run_ioloop) self.assertEqual(response, 'Hello from tornado!') +@skipIfNoTwisted +@skipIfNoSingleDispatch +class ConvertDeferredTest(unittest.TestCase): + def test_success(self): + @inlineCallbacks + def fn(): + if False: + # 
inlineCallbacks doesn't work with regular functions; + # must have a yield even if it's unreachable. + yield + returnValue(42) + f = gen.convert_yielded(fn()) + self.assertEqual(f.result(), 42) + + def test_failure(self): + @inlineCallbacks + def fn(): + if False: + yield + 1 / 0 + f = gen.convert_yielded(fn()) + with self.assertRaises(ZeroDivisionError): + f.result() + + if have_twisted: # Import and run as much of twisted's test suite as possible. # This is unfortunately rather dependent on implementation details, diff --git a/tornado/test/util.py b/tornado/test/util.py index d31bbba33d8019865d1c44abd43eb246fa639c1f..358809f216e247014b2401089c9744a5cde336cc 100644 --- a/tornado/test/util.py +++ b/tornado/test/util.py @@ -4,6 +4,8 @@ import os import socket import sys +from tornado.testing import bind_unused_port + # Encapsulate the choice of unittest or unittest2 here. # To be used as 'from tornado.test.util import unittest'. if sys.version_info < (2, 7): @@ -28,3 +30,22 @@ skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ, 'network access disabled') skipIfNoIPv6 = unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present') + +def refusing_port(): + """Returns a local port number that will refuse all connections. + + Return value is (cleanup_func, port); the cleanup function + must be called to free the port to be reused. + """ + # On travis-ci, port numbers are reassigned frequently. To avoid + # collisions with other tests, we use an open client-side socket's + # ephemeral port number to ensure that nothing can listen on that + # port. 
+ server_socket, port = bind_unused_port() + server_socket.setblocking(1) + client_socket = socket.socket() + client_socket.connect(("127.0.0.1", port)) + conn, client_addr = server_socket.accept() + conn.close() + server_socket.close() + return (client_socket.close, client_addr[1]) diff --git a/tornado/test/web_test.py b/tornado/test/web_test.py index 55c9c9e8c99075cec964e8173e461cfb3c43d8aa..77ad388812de95585673b261e6b608b8f41a25d3 100644 --- a/tornado/test/web_test.py +++ b/tornado/test/web_test.py @@ -305,7 +305,7 @@ class ConnectionCloseTest(WebTestCase): def test_connection_close(self): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) - s.connect(("localhost", self.get_http_port())) + s.connect(("127.0.0.1", self.get_http_port())) self.stream = IOStream(s, io_loop=self.io_loop) self.stream.write(b"GET / HTTP/1.0\r\n\r\n") self.wait() @@ -1907,7 +1907,7 @@ class StreamingRequestBodyTest(WebTestCase): def connect(self, url, connection_close): # Use a raw connection so we can control the sending of data. 
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) - s.connect(("localhost", self.get_http_port())) + s.connect(("127.0.0.1", self.get_http_port())) stream = IOStream(s, io_loop=self.io_loop) stream.write(b"GET " + url + b" HTTP/1.1\r\n") if connection_close: diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index e1e3ea7005d1c396a9f47dcdb65daf4c2fd1aec6..7e93d17141a886a9cb06a5f840a11f7e723b9dae 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -75,6 +75,7 @@ class NonWebSocketHandler(RequestHandler): class CloseReasonHandler(TestWebSocketHandler): def open(self): + self.on_close_called = False self.close(1001, "goodbye") @@ -91,7 +92,7 @@ class WebSocketBaseTestCase(AsyncHTTPTestCase): @gen.coroutine def ws_connect(self, path, compression_options=None): ws = yield websocket_connect( - 'ws://localhost:%d%s' % (self.get_http_port(), path), + 'ws://127.0.0.1:%d%s' % (self.get_http_port(), path), compression_options=compression_options) raise gen.Return(ws) @@ -135,7 +136,7 @@ class WebSocketTest(WebSocketBaseTestCase): def test_websocket_callbacks(self): websocket_connect( - 'ws://localhost:%d/echo' % self.get_http_port(), + 'ws://127.0.0.1:%d/echo' % self.get_http_port(), io_loop=self.io_loop, callback=self.stop) ws = self.wait().result() ws.write_message('hello') @@ -189,14 +190,14 @@ class WebSocketTest(WebSocketBaseTestCase): with self.assertRaises(IOError): with ExpectLog(gen_log, ".*"): yield websocket_connect( - 'ws://localhost:%d/' % port, + 'ws://127.0.0.1:%d/' % port, io_loop=self.io_loop, connect_timeout=3600) @gen_test def test_websocket_close_buffered_data(self): ws = yield websocket_connect( - 'ws://localhost:%d/echo' % self.get_http_port()) + 'ws://127.0.0.1:%d/echo' % self.get_http_port()) ws.write_message('hello') ws.write_message('world') # Close the underlying stream. 
@@ -207,7 +208,7 @@ class WebSocketTest(WebSocketBaseTestCase): def test_websocket_headers(self): # Ensure that arbitrary headers can be passed through websocket_connect. ws = yield websocket_connect( - HTTPRequest('ws://localhost:%d/header' % self.get_http_port(), + HTTPRequest('ws://127.0.0.1:%d/header' % self.get_http_port(), headers={'X-Test': 'hello'})) response = yield ws.read_message() self.assertEqual(response, 'hello') @@ -221,6 +222,8 @@ class WebSocketTest(WebSocketBaseTestCase): self.assertIs(msg, None) self.assertEqual(ws.close_code, 1001) self.assertEqual(ws.close_reason, "goodbye") + # The on_close callback is called no matter which side closed. + yield self.close_future @gen_test def test_client_close_reason(self): @@ -243,8 +246,8 @@ class WebSocketTest(WebSocketBaseTestCase): def test_check_origin_valid_no_path(self): port = self.get_http_port() - url = 'ws://localhost:%d/echo' % port - headers = {'Origin': 'http://localhost:%d' % port} + url = 'ws://127.0.0.1:%d/echo' % port + headers = {'Origin': 'http://127.0.0.1:%d' % port} ws = yield websocket_connect(HTTPRequest(url, headers=headers), io_loop=self.io_loop) @@ -257,8 +260,8 @@ class WebSocketTest(WebSocketBaseTestCase): def test_check_origin_valid_with_path(self): port = self.get_http_port() - url = 'ws://localhost:%d/echo' % port - headers = {'Origin': 'http://localhost:%d/something' % port} + url = 'ws://127.0.0.1:%d/echo' % port + headers = {'Origin': 'http://127.0.0.1:%d/something' % port} ws = yield websocket_connect(HTTPRequest(url, headers=headers), io_loop=self.io_loop) @@ -271,8 +274,8 @@ class WebSocketTest(WebSocketBaseTestCase): def test_check_origin_invalid_partial_url(self): port = self.get_http_port() - url = 'ws://localhost:%d/echo' % port - headers = {'Origin': 'localhost:%d' % port} + url = 'ws://127.0.0.1:%d/echo' % port + headers = {'Origin': '127.0.0.1:%d' % port} with self.assertRaises(HTTPError) as cm: yield websocket_connect(HTTPRequest(url, headers=headers), @@ -283,8 
+286,8 @@ class WebSocketTest(WebSocketBaseTestCase): def test_check_origin_invalid(self): port = self.get_http_port() - url = 'ws://localhost:%d/echo' % port - # Host is localhost, which should not be accessible from some other + url = 'ws://127.0.0.1:%d/echo' % port + # Host is 127.0.0.1, which should not be accessible from some other # domain headers = {'Origin': 'http://somewhereelse.com'} diff --git a/tornado/testing.py b/tornado/testing.py index 4d85abe997372611e1bc2d1a23287329f967aafd..3d3bcf72b974d9cc0f8231982da00fef19766894 100644 --- a/tornado/testing.py +++ b/tornado/testing.py @@ -19,6 +19,7 @@ try: from tornado.simple_httpclient import SimpleAsyncHTTPClient from tornado.ioloop import IOLoop, TimeoutError from tornado import netutil + from tornado.process import Subprocess except ImportError: # These modules are not importable on app engine. Parts of this module # won't work, but e.g. LogTrapTestCase and main() will. @@ -28,6 +29,7 @@ except ImportError: IOLoop = None netutil = None SimpleAsyncHTTPClient = None + Subprocess = None from tornado.log import gen_log, app_log from tornado.stack_context import ExceptionStackContext from tornado.util import raise_exc_info, basestring_type @@ -214,6 +216,8 @@ class AsyncTestCase(unittest.TestCase): self.io_loop.make_current() def tearDown(self): + # Clean up Subprocess, so it can be used again with a new ioloop. + Subprocess.uninitialize() self.io_loop.clear_current() if (not IOLoop.initialized() or self.io_loop is not IOLoop.instance()): @@ -539,6 +543,9 @@ class LogTrapTestCase(unittest.TestCase): `logging.basicConfig` and the "pretty logging" configured by `tornado.options`. It is not compatible with other log buffering mechanisms, such as those provided by some test runners. + + .. deprecated:: 4.1 + Use the unittest module's ``--buffer`` option instead, or `.ExpectLog`. 
""" def run(self, result=None): logger = logging.getLogger() diff --git a/tornado/web.py b/tornado/web.py index 2cb5afad6f3b8201e62fe7ac236b8a9b59caa577..52bfce3663f8d1c585e25cff67c3aafd5f9114ab 100644 --- a/tornado/web.py +++ b/tornado/web.py @@ -85,6 +85,7 @@ from tornado import stack_context from tornado import template from tornado.escape import utf8, _unicode from tornado.util import import_object, ObjectDict, raise_exc_info, unicode_type, _websocket_mask +from tornado.httputil import split_host_and_port try: @@ -267,6 +268,7 @@ class RequestHandler(object): if _has_stream_request_body(self.__class__): if not self.request.body.done(): self.request.body.set_exception(iostream.StreamClosedError()) + self.request.body.exception() def clear(self): """Resets all headers and content for this response.""" @@ -839,7 +841,7 @@ class RequestHandler(object): for cookie in self._new_cookie.values(): self.add_header("Set-Cookie", cookie.OutputString(None)) - start_line = httputil.ResponseStartLine(self.request.version, + start_line = httputil.ResponseStartLine('', self._status_code, self._reason) return self.request.connection.write_headers( @@ -1119,28 +1121,36 @@ class RequestHandler(object): """Convert a cookie string into a the tuple form returned by _get_raw_xsrf_token. 
""" - m = _signed_value_version_re.match(utf8(cookie)) - if m: - version = int(m.group(1)) - if version == 2: - _, mask, masked_token, timestamp = cookie.split("|") - mask = binascii.a2b_hex(utf8(mask)) - token = _websocket_mask( - mask, binascii.a2b_hex(utf8(masked_token))) - timestamp = int(timestamp) - return version, token, timestamp + + try: + m = _signed_value_version_re.match(utf8(cookie)) + + if m: + version = int(m.group(1)) + if version == 2: + _, mask, masked_token, timestamp = cookie.split("|") + + mask = binascii.a2b_hex(utf8(mask)) + token = _websocket_mask( + mask, binascii.a2b_hex(utf8(masked_token))) + timestamp = int(timestamp) + return version, token, timestamp + else: + # Treat unknown versions as not present instead of failing. + raise Exception("Unknown xsrf cookie version") else: - # Treat unknown versions as not present instead of failing. - return None, None, None - else: - version = 1 - try: - token = binascii.a2b_hex(utf8(cookie)) - except (binascii.Error, TypeError): - token = utf8(cookie) - # We don't have a usable timestamp in older versions. - timestamp = int(time.time()) - return (version, token, timestamp) + version = 1 + try: + token = binascii.a2b_hex(utf8(cookie)) + except (binascii.Error, TypeError): + token = utf8(cookie) + # We don't have a usable timestamp in older versions. + timestamp = int(time.time()) + return (version, token, timestamp) + except Exception: + # Catch exceptions and return nothing instead of failing. + gen_log.debug("Uncaught exception in _decode_xsrf_token", exc_info=True) + return None, None, None def check_xsrf_cookie(self): """Verifies that the ``_xsrf`` cookie matches the ``_xsrf`` argument. 
@@ -1477,7 +1487,7 @@ def asynchronous(method): with stack_context.ExceptionStackContext( self._stack_context_handle_exception): result = method(self, *args, **kwargs) - if isinstance(result, Future): + if is_future(result): # If @asynchronous is used with @gen.coroutine, (but # not @gen.engine), we can automatically finish the # request when the future resolves. Additionally, @@ -1518,7 +1528,7 @@ def stream_request_body(cls): the entire body has been read. There is a subtle interaction between ``data_received`` and asynchronous - ``prepare``: The first call to ``data_recieved`` may occur at any point + ``prepare``: The first call to ``data_received`` may occur at any point after the call to ``prepare`` has returned *or yielded*. """ if not issubclass(cls, RequestHandler): @@ -1729,7 +1739,7 @@ class Application(httputil.HTTPServerConnectionDelegate): self.transforms.append(transform_class) def _get_host_handlers(self, request): - host = request.host.lower().split(':')[0] + host = split_host_and_port(request.host.lower())[0] matches = [] for pattern, handlers in self.handlers: if pattern.match(host): @@ -1770,9 +1780,9 @@ class Application(httputil.HTTPServerConnectionDelegate): except TypeError: pass - def start_request(self, connection): + def start_request(self, server_conn, request_conn): # Modern HTTPServer interface - return _RequestDispatcher(self, connection) + return _RequestDispatcher(self, request_conn) def __call__(self, request): # Legacy HTTPServer interface @@ -1845,7 +1855,7 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate): handlers = app._get_host_handlers(self.request) if not handlers: self.handler_class = RedirectHandler - self.handler_kwargs = dict(url="http://" + app.default_host + "/") + self.handler_kwargs = dict(url="%s://%s/" % (self.request.protocol, app.default_host)) return for spec in handlers: match = spec.regex.match(self.request.path) @@ -1914,8 +1924,10 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate): # trapped 
in the Future it returns (which we are ignoring here). # However, that shouldn't happen because _execute has a blanket # except handler, and we cannot easily access the IOLoop here to - # call add_future. - self.handler._execute(transforms, *self.path_args, **self.path_kwargs) + # call add_future (because of the requirement to remain compatible + # with WSGI) + f = self.handler._execute(transforms, *self.path_args, **self.path_kwargs) + f.add_done_callback(lambda f: f.exception()) # If we are streaming the request body, then execute() is finished # when the handler has prepared to receive the body. If not, # it doesn't matter when execute() finishes (so we return None) @@ -2017,7 +2029,6 @@ class RedirectHandler(RequestHandler): (r"/oldpath", web.RedirectHandler, {"url": "/newpath"}), ]) """ - def initialize(self, url, permanent=True): self._url = url self._permanent = permanent @@ -2622,6 +2633,8 @@ class UIModule(object): UI modules often execute additional queries, and they can include additional CSS and JavaScript that will be included in the output page, which is automatically inserted on page render. + + Subclasses of UIModule must override the `render` method. """ def __init__(self, handler): self.handler = handler @@ -2634,31 +2647,43 @@ class UIModule(object): return self.handler.current_user def render(self, *args, **kwargs): - """Overridden in subclasses to return this module's output.""" + """Override in subclasses to return this module's output.""" raise NotImplementedError() def embedded_javascript(self): - """Returns a JavaScript string that will be embedded in the page.""" + """Override to return a JavaScript string to be embedded in the page.""" return None def javascript_files(self): - """Returns a list of JavaScript files required by this module.""" + """Override to return a list of JavaScript files needed by this module. 
+ + If the return values are relative paths, they will be passed to + `RequestHandler.static_url`; otherwise they will be used as-is. + """ return None def embedded_css(self): - """Returns a CSS string that will be embedded in the page.""" + """Override to return a CSS string that will be embedded in the page.""" return None def css_files(self): - """Returns a list of CSS files required by this module.""" + """Override to return a list of CSS files required by this module. + + If the return values are relative paths, they will be passed to + `RequestHandler.static_url`; otherwise they will be used as-is. + """ return None def html_head(self): - """Returns a CSS string that will be put in the <head/> element""" + """Override to return an HTML string that will be put in the <head/> + element. + """ return None def html_body(self): - """Returns an HTML string that will be put in the <body/> element""" + """Override to return an HTML string that will be put at the end of + the <body/> element.
+ """ return None def render_string(self, path, **kwargs): diff --git a/tornado/websocket.py b/tornado/websocket.py index d960b0e40faa04401b19a34f6e8dbf42ed4a25ce..c009225ce51decfd59ea964855648cbe10babe92 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -129,6 +129,7 @@ class WebSocketHandler(tornado.web.RequestHandler): self.close_code = None self.close_reason = None self.stream = None + self._on_close_called = False @tornado.web.asynchronous def get(self, *args, **kwargs): @@ -171,18 +172,16 @@ class WebSocketHandler(tornado.web.RequestHandler): self.stream = self.request.connection.detach() self.stream.set_close_callback(self.on_connection_close) - if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"): - self.ws_connection = WebSocketProtocol13( - self, compression_options=self.get_compression_options()) + self.ws_connection = self.get_websocket_protocol() + if self.ws_connection: self.ws_connection.accept_connection() else: if not self.stream.closed(): self.stream.write(tornado.escape.utf8( "HTTP/1.1 426 Upgrade Required\r\n" - "Sec-WebSocket-Version: 8\r\n\r\n")) + "Sec-WebSocket-Version: 7, 8, 13\r\n\r\n")) self.stream.close() - def write_message(self, message, binary=False): """Sends the given message to the client of this Web Socket. @@ -229,7 +228,7 @@ class WebSocketHandler(tornado.web.RequestHandler): """ return None - def open(self): + def open(self, *args, **kwargs): """Invoked when a new WebSocket is opened. The arguments to `open` are extracted from the `tornado.web.URLSpec` @@ -350,6 +349,8 @@ class WebSocketHandler(tornado.web.RequestHandler): if self.ws_connection: self.ws_connection.on_connection_close() self.ws_connection = None + if not self._on_close_called: + self._on_close_called self.on_close() def send_error(self, *args, **kwargs): @@ -362,6 +363,13 @@ class WebSocketHandler(tornado.web.RequestHandler): # we can close the connection more gracefully. 
self.stream.close() + def get_websocket_protocol(self): + websocket_version = self.request.headers.get("Sec-WebSocket-Version") + if websocket_version in ("7", "8", "13"): + return WebSocketProtocol13( + self, compression_options=self.get_compression_options()) + + def _wrap_method(method): def _disallow_for_websocket(self, *args, **kwargs): if self.stream is None: @@ -852,12 +860,15 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): This class should not be instantiated directly; use the `websocket_connect` function instead. """ - def __init__(self, io_loop, request, compression_options=None): + def __init__(self, io_loop, request, on_message_callback=None, + compression_options=None): self.compression_options = compression_options self.connect_future = TracebackFuture() + self.protocol = None self.read_future = None self.read_queue = collections.deque() self.key = base64.b64encode(os.urandom(16)) + self._on_message_callback = on_message_callback scheme, sep, rest = request.url.partition(':') scheme = {'ws': 'http', 'wss': 'https'}[scheme] @@ -919,9 +930,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): start_line, headers) self.headers = headers - self.protocol = WebSocketProtocol13( - self, mask_outgoing=True, - compression_options=self.compression_options) + self.protocol = self.get_websocket_protocol() self.protocol._process_server_headers(self.key, self.headers) self.protocol._receive_frame() @@ -946,6 +955,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): def read_message(self, callback=None): """Reads a message from the WebSocket server. + If on_message_callback was specified at WebSocket + initialization, this function will never return messages + Returns a future whose result is the message, or None if the connection is closed. 
If a callback argument is given it will be called with the future when it is @@ -962,7 +974,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): return future def on_message(self, message): - if self.read_future is not None: + if self._on_message_callback: + self._on_message_callback(message) + elif self.read_future is not None: self.read_future.set_result(message) self.read_future = None else: @@ -971,9 +985,13 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): def on_pong(self, data): pass + def get_websocket_protocol(self): + return WebSocketProtocol13(self, mask_outgoing=True, + compression_options=self.compression_options) + def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None, - compression_options=None): + on_message_callback=None, compression_options=None): """Client-side websocket support. Takes a url and returns a Future whose result is a @@ -982,11 +1000,26 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None, ``compression_options`` is interpreted in the same way as the return value of `.WebSocketHandler.get_compression_options`. + The connection supports two styles of operation. In the coroutine + style, the application typically calls + `~.WebSocketClientConnection.read_message` in a loop:: + + conn = yield websocket_connect(url) + while True: + msg = yield conn.read_message() + if msg is None: break + # Do something with msg + + In the callback style, pass an ``on_message_callback`` to + ``websocket_connect``. In both styles, a message of ``None`` + indicates that the connection has been closed. + .. versionchanged:: 3.2 Also accepts ``HTTPRequest`` objects in place of urls. .. versionchanged:: 4.1 - Added ``compression_options``. + Added ``compression_options`` and ``on_message_callback``. + The ``io_loop`` argument is deprecated.
""" if io_loop is None: io_loop = IOLoop.current() @@ -1000,7 +1033,9 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None, request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout) request = httpclient._RequestProxy( request, httpclient.HTTPRequest._DEFAULTS) - conn = WebSocketClientConnection(io_loop, request, compression_options) + conn = WebSocketClientConnection(io_loop, request, + on_message_callback=on_message_callback, + compression_options=compression_options) if callback is not None: io_loop.add_future(conn.connect_future, callback) return conn.connect_future diff --git a/tornado/wsgi.py b/tornado/wsgi.py index f3aa66503c3bfadb1b2b0b24cbe8961c70967274..e7e07fbc9cf3e960eebe66b8263f198476a1af35 100644 --- a/tornado/wsgi.py +++ b/tornado/wsgi.py @@ -207,7 +207,7 @@ class WSGIAdapter(object): body = environ["wsgi.input"].read( int(headers["Content-Length"])) else: - body = "" + body = b"" protocol = environ["wsgi.url_scheme"] remote_ip = environ.get("REMOTE_ADDR", "") if environ.get("HTTP_HOST"):