NEW: extra where

This commit is contained in:
2014-11-30 09:50:00 +01:00
parent 0bc5bbbad4
commit 0563b393f6
2 changed files with 40 additions and 19 deletions

View File

@@ -2,7 +2,7 @@ from __future__ import unicode_literals
import datetime import datetime
import copy import copy
VERSION = "0.0.6" VERSION = "0.0.8"
class classproperty(object): class classproperty(object):
@@ -183,7 +183,7 @@ class SQLQuery(object):
if table: if table:
self._table = table self._table = table
self.sql_mode = sql_mode self.sql_mode = sql_mode
self._values = [] self._values = ["*"]
self._order_by = [] self._order_by = []
self._group_by = [] self._group_by = []
self._joins = [] self._joins = []
@@ -197,7 +197,7 @@ class SQLQuery(object):
def values(self, *args): def values(self, *args):
clone = self._clone() clone = self._clone()
clone._values = args clone._values = list(args)
return clone return clone
def filter(self, *args, **kwargs): def filter(self, *args, **kwargs):
@@ -235,9 +235,12 @@ class SQLQuery(object):
clone._joins.append("{how} {table} {on}".format(how=how, table=table, on=on)) clone._joins.append("{how} {table} {on}".format(how=how, table=table, on=on))
return clone return clone
def extra(self, extra): def extra(self, extra=None, **kwargs):
clone = self._clone() clone = self._clone()
clone._extra.update(extra) if extra:
clone._extra.update(extra)
if kwargs:
clone._extra.update(kwargs)
return clone return clone
def __getitem__(self, slice): def __getitem__(self, slice):
@@ -249,22 +252,27 @@ class SQLQuery(object):
class SQLCompiler(object): class SQLCompiler(object):
def get_columns(self,): def get_columns(self,):
if self._values: extra_columns = self.get_extra_columns()
return ", ".join(self._values) columns = ", ".join(self._values)
return "*" return ", ".join([item for item in [columns, extra_columns] if item])
def get_extra_columns(self,): def get_extra_columns(self,):
select = self._extra.get("select", None) return self._extra.get("select", "")
if select:
return ", " + select def get_extra_where(self,):
where = self._extra.get("where", [])
if where:
return " AND ".join(where)
def get_table(self,): def get_table(self,):
return self._table return self._table
def get_where(self): def get_where(self):
filters = self._filters & ~self._excludes filters = unicode(self._filters & ~self._excludes)
if filters: extra_where = self.get_extra_where()
return "WHERE " + unicode(filters)
if filters or extra_where:
return "WHERE " + " ".join([item for item in [filters, extra_where] if item])
def get_order_by(self,): def get_order_by(self,):
conds = [] conds = []
@@ -307,10 +315,11 @@ class SQLCompiler(object):
return "TOP {0}".format(self._limits.stop) return "TOP {0}".format(self._limits.stop)
def _compile(self): def _compile(self):
sql_all = [ sql_all = ["SELECT", self.get_top(), self.get_columns(),
"SELECT", self.get_top(), self.get_columns(), self.get_extra_columns(), "FROM", self.get_table(),
"FROM", self.get_table(), self.get_joins(), self.get_where(), self.get_joins(), self.get_where(),
self.get_group_by(), self.get_order_by(), self.get_limits()] self.get_group_by(), self.get_order_by(),
self.get_limits()]
return " ".join([item for item in sql_all if item]) return " ".join([item for item in sql_all if item])

View File

@@ -66,9 +66,21 @@ class TestSqlBuilder(unittest.TestCase):
self.assertEqual( self.assertEqual(
str(sql), "SELECT name, date, tlf FROM users WHERE ((name='jhon') AND NOT (DATEPART('year', date)<=1977))") str(sql), "SELECT name, date, tlf FROM users WHERE ((name='jhon') AND NOT (DATEPART('year', date)<=1977))")
def test_extra(self):
sql = Queryset("users").values("name", "date", "tlf")
sql = sql.extra({'select': 'count(*) as total'}) sql = sql.extra({'select': 'count(*) as total'})
self.assertEqual( self.assertEqual(
str(sql), "SELECT name, date, tlf , count(*) as total FROM users WHERE ((name='jhon') AND NOT (DATEPART('year', date)<=1977))") str(sql), "SELECT name, date, tlf, count(*) as total FROM users")
sql = Queryset("users")
sql = sql.extra(where=["id=1", "name='jose'"])
sql = sql.extra(select="count(*) as total")
self.assertEqual(
str(sql), "SELECT *, count(*) as total FROM users WHERE id=1 AND name='jose'")
sql = sql.values(*[])
self.assertEqual(
str(sql), "SELECT count(*) as total FROM users WHERE id=1 AND name='jose'")
def test_in(self,): def test_in(self,):
sql = Queryset("users") sql = Queryset("users")