NEW: extra where
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import unicode_literals
|
|||||||
import datetime
|
import datetime
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
VERSION = "0.0.6"
|
VERSION = "0.0.8"
|
||||||
|
|
||||||
|
|
||||||
class classproperty(object):
|
class classproperty(object):
|
||||||
@@ -183,7 +183,7 @@ class SQLQuery(object):
|
|||||||
if table:
|
if table:
|
||||||
self._table = table
|
self._table = table
|
||||||
self.sql_mode = sql_mode
|
self.sql_mode = sql_mode
|
||||||
self._values = []
|
self._values = ["*"]
|
||||||
self._order_by = []
|
self._order_by = []
|
||||||
self._group_by = []
|
self._group_by = []
|
||||||
self._joins = []
|
self._joins = []
|
||||||
@@ -197,7 +197,7 @@ class SQLQuery(object):
|
|||||||
|
|
||||||
def values(self, *args):
|
def values(self, *args):
|
||||||
clone = self._clone()
|
clone = self._clone()
|
||||||
clone._values = args
|
clone._values = list(args)
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
def filter(self, *args, **kwargs):
|
def filter(self, *args, **kwargs):
|
||||||
@@ -235,9 +235,12 @@ class SQLQuery(object):
|
|||||||
clone._joins.append("{how} {table} {on}".format(how=how, table=table, on=on))
|
clone._joins.append("{how} {table} {on}".format(how=how, table=table, on=on))
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
def extra(self, extra):
|
def extra(self, extra=None, **kwargs):
|
||||||
clone = self._clone()
|
clone = self._clone()
|
||||||
clone._extra.update(extra)
|
if extra:
|
||||||
|
clone._extra.update(extra)
|
||||||
|
if kwargs:
|
||||||
|
clone._extra.update(kwargs)
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
def __getitem__(self, slice):
|
def __getitem__(self, slice):
|
||||||
@@ -249,22 +252,27 @@ class SQLQuery(object):
|
|||||||
class SQLCompiler(object):
|
class SQLCompiler(object):
|
||||||
|
|
||||||
def get_columns(self,):
|
def get_columns(self,):
|
||||||
if self._values:
|
extra_columns = self.get_extra_columns()
|
||||||
return ", ".join(self._values)
|
columns = ", ".join(self._values)
|
||||||
return "*"
|
return ", ".join([item for item in [columns, extra_columns] if item])
|
||||||
|
|
||||||
def get_extra_columns(self,):
|
def get_extra_columns(self,):
|
||||||
select = self._extra.get("select", None)
|
return self._extra.get("select", "")
|
||||||
if select:
|
|
||||||
return ", " + select
|
def get_extra_where(self,):
|
||||||
|
where = self._extra.get("where", [])
|
||||||
|
if where:
|
||||||
|
return " AND ".join(where)
|
||||||
|
|
||||||
def get_table(self,):
|
def get_table(self,):
|
||||||
return self._table
|
return self._table
|
||||||
|
|
||||||
def get_where(self):
|
def get_where(self):
|
||||||
filters = self._filters & ~self._excludes
|
filters = unicode(self._filters & ~self._excludes)
|
||||||
if filters:
|
extra_where = self.get_extra_where()
|
||||||
return "WHERE " + unicode(filters)
|
|
||||||
|
if filters or extra_where:
|
||||||
|
return "WHERE " + " ".join([item for item in [filters, extra_where] if item])
|
||||||
|
|
||||||
def get_order_by(self,):
|
def get_order_by(self,):
|
||||||
conds = []
|
conds = []
|
||||||
@@ -307,10 +315,11 @@ class SQLCompiler(object):
|
|||||||
return "TOP {0}".format(self._limits.stop)
|
return "TOP {0}".format(self._limits.stop)
|
||||||
|
|
||||||
def _compile(self):
|
def _compile(self):
|
||||||
sql_all = [
|
sql_all = ["SELECT", self.get_top(), self.get_columns(),
|
||||||
"SELECT", self.get_top(), self.get_columns(), self.get_extra_columns(),
|
"FROM", self.get_table(),
|
||||||
"FROM", self.get_table(), self.get_joins(), self.get_where(),
|
self.get_joins(), self.get_where(),
|
||||||
self.get_group_by(), self.get_order_by(), self.get_limits()]
|
self.get_group_by(), self.get_order_by(),
|
||||||
|
self.get_limits()]
|
||||||
|
|
||||||
return " ".join([item for item in sql_all if item])
|
return " ".join([item for item in sql_all if item])
|
||||||
|
|
||||||
|
|||||||
@@ -66,9 +66,21 @@ class TestSqlBuilder(unittest.TestCase):
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
str(sql), "SELECT name, date, tlf FROM users WHERE ((name='jhon') AND NOT (DATEPART('year', date)<=1977))")
|
str(sql), "SELECT name, date, tlf FROM users WHERE ((name='jhon') AND NOT (DATEPART('year', date)<=1977))")
|
||||||
|
|
||||||
|
def test_extra(self):
|
||||||
|
sql = Queryset("users").values("name", "date", "tlf")
|
||||||
sql = sql.extra({'select': 'count(*) as total'})
|
sql = sql.extra({'select': 'count(*) as total'})
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
str(sql), "SELECT name, date, tlf , count(*) as total FROM users WHERE ((name='jhon') AND NOT (DATEPART('year', date)<=1977))")
|
str(sql), "SELECT name, date, tlf, count(*) as total FROM users")
|
||||||
|
|
||||||
|
sql = Queryset("users")
|
||||||
|
sql = sql.extra(where=["id=1", "name='jose'"])
|
||||||
|
sql = sql.extra(select="count(*) as total")
|
||||||
|
self.assertEqual(
|
||||||
|
str(sql), "SELECT *, count(*) as total FROM users WHERE id=1 AND name='jose'")
|
||||||
|
|
||||||
|
sql = sql.values(*[])
|
||||||
|
self.assertEqual(
|
||||||
|
str(sql), "SELECT count(*) as total FROM users WHERE id=1 AND name='jose'")
|
||||||
|
|
||||||
def test_in(self,):
|
def test_in(self,):
|
||||||
sql = Queryset("users")
|
sql = Queryset("users")
|
||||||
|
|||||||
Reference in New Issue
Block a user