Private GIT

Skip to content
Snippets Groups Projects
Commit 9702bd4f authored by bone's avatar bone Committed by miigotu
Browse files

db.py: DBConnection

	manipulation of row_factory on sql connection was not done in a thread safe manner

	factory methods should be static to avoid holding a reference
parent 574b10ca
No related branches found
No related tags found
No related merge requests found
...@@ -59,29 +59,36 @@ class DBConnection(object): ...@@ -59,29 +59,36 @@ class DBConnection(object):
db_locks[self.filename] = threading.Lock() db_locks[self.filename] = threading.Lock()
self.connection = sqlite3.connect(dbFilename(self.filename, self.suffix), 20, check_same_thread=False) self.connection = sqlite3.connect(dbFilename(self.filename, self.suffix), 20, check_same_thread=False)
self.connection.text_factory = self._unicode_text_factory self.connection.text_factory = DBConnection._unicode_text_factory
db_cons[self.filename] = self.connection db_cons[self.filename] = self.connection
else: else:
self.connection = db_cons[self.filename] self.connection = db_cons[self.filename]
if self.row_type == "dict": # start off row factory configured as before out of
self.connection.row_factory = self._dict_factory # paranoia but wait to do so until other potential users
else: # of the shared connection are done using
self.connection.row_factory = sqlite3.Row # it... technically not required as row factory is reset
# in all the public methods after the lock has been
# aquired
with db_locks[self.filename]:
self._set_row_factory()
except Exception as e: except Exception as e:
logger.log(u"DB error: " + ex(e), logger.ERROR) logger.log(u"DB error: " + ex(e), logger.ERROR)
raise raise
def _execute(self, query, args): def _set_row_factory(self):
try: """
if not args: once lock is aquired we can configure the connection for
return self.connection.cursor().execute(query) this particular instance of DBConnection
return self.connection.cursor().execute(query, args) """
except Exception: if self.row_type == "dict":
raise self.connection.row_factory = DBConnection._dict_factory
else:
self.connection.row_factory = sqlite3.Row
def execute(self, query, args=None, fetchall=False, fetchone=False): def _execute(self, query, args=None, fetchall=False, fetchone=False):
""" """
Executes DB query Executes DB query
...@@ -92,12 +99,16 @@ class DBConnection(object): ...@@ -92,12 +99,16 @@ class DBConnection(object):
:return: query results :return: query results
""" """
try: try:
if not args:
sqlResult = self.connection.cursor().execute(query)
else:
sqlResult = self.connection.cursor().execute(query, args)
if fetchall: if fetchall:
return self._execute(query, args).fetchall() return sqlResult.fetchall()
elif fetchone: elif fetchone:
return self._execute(query, args).fetchone() return sqlResult.fetchone()
else: else:
return self._execute(query, args) return sqlResult
except Exception: except Exception:
raise raise
...@@ -136,17 +147,18 @@ class DBConnection(object): ...@@ -136,17 +147,18 @@ class DBConnection(object):
attempt = 0 attempt = 0
with db_locks[self.filename]: with db_locks[self.filename]:
self._set_row_factory()
while attempt < 5: while attempt < 5:
try: try:
for qu in querylist: for qu in querylist:
if len(qu) == 1: if len(qu) == 1:
if logTransaction: if logTransaction:
logger.log(qu[0], logger.DEBUG) logger.log(qu[0], logger.DEBUG)
sqlResult.append(self.execute(qu[0], fetchall=fetchall)) sqlResult.append(self._execute(qu[0], fetchall=fetchall))
elif len(qu) > 1: elif len(qu) > 1:
if logTransaction: if logTransaction:
logger.log(qu[0] + " with args " + str(qu[1]), logger.DEBUG) logger.log(qu[0] + " with args " + str(qu[1]), logger.DEBUG)
sqlResult.append(self.execute(qu[0], qu[1], fetchall=fetchall)) sqlResult.append(self._execute(qu[0], qu[1], fetchall=fetchall))
self.connection.commit() self.connection.commit()
logger.log(u"Transaction with " + str(len(querylist)) + u" queries executed", logger.DEBUG) logger.log(u"Transaction with " + str(len(querylist)) + u" queries executed", logger.DEBUG)
...@@ -191,6 +203,7 @@ class DBConnection(object): ...@@ -191,6 +203,7 @@ class DBConnection(object):
attempt = 0 attempt = 0
with db_locks[self.filename]: with db_locks[self.filename]:
self._set_row_factory()
while attempt < 5: while attempt < 5:
try: try:
if args is None: if args is None:
...@@ -198,7 +211,7 @@ class DBConnection(object): ...@@ -198,7 +211,7 @@ class DBConnection(object):
else: else:
logger.log(self.filename + ": " + query + " with args " + str(args), logger.DB) logger.log(self.filename + ": " + query + " with args " + str(args), logger.DB)
sqlResult = self.execute(query, args, fetchall=fetchall, fetchone=fetchone) sqlResult = self._execute(query, args, fetchall=fetchall, fetchone=fetchone)
self.connection.commit() self.connection.commit()
# get out of the connection attempt loop since we were successful # get out of the connection attempt loop since we were successful
...@@ -287,7 +300,8 @@ class DBConnection(object): ...@@ -287,7 +300,8 @@ class DBConnection(object):
columns[column['name']] = {'type': column['type']} columns[column['name']] = {'type': column['type']}
return columns return columns
def _unicode_text_factory(self, x): @staticmethod
def _unicode_text_factory(x):
""" """
Convert text to unicode Convert text to unicode
...@@ -300,7 +314,8 @@ class DBConnection(object): ...@@ -300,7 +314,8 @@ class DBConnection(object):
except: except:
return unicode(x, sickbeard.SYS_ENCODING, errors="ignore") return unicode(x, sickbeard.SYS_ENCODING, errors="ignore")
def _dict_factory(self, cursor, row): @staticmethod
def _dict_factory(cursor, row):
d = {} d = {}
for idx, col in enumerate(cursor.description): for idx, col in enumerate(cursor.description):
d[col[0]] = row[idx] d[col[0]] = row[idx]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment