NEW: extra where

This commit is contained in:
2014-11-30 09:50:00 +01:00
parent 0bc5bbbad4
commit 0563b393f6
2 changed files with 40 additions and 19 deletions

View File

@@ -2,7 +2,7 @@ from __future__ import unicode_literals
import datetime
import copy
VERSION = "0.0.6"
VERSION = "0.0.8"
class classproperty(object):
@@ -183,7 +183,7 @@ class SQLQuery(object):
if table:
self._table = table
self.sql_mode = sql_mode
self._values = []
self._values = ["*"]
self._order_by = []
self._group_by = []
self._joins = []
@@ -197,7 +197,7 @@ class SQLQuery(object):
def values(self, *args):
clone = self._clone()
clone._values = args
clone._values = list(args)
return clone
def filter(self, *args, **kwargs):
@@ -235,9 +235,12 @@ class SQLQuery(object):
clone._joins.append("{how} {table} {on}".format(how=how, table=table, on=on))
return clone
def extra(self, extra):
def extra(self, extra=None, **kwargs):
clone = self._clone()
clone._extra.update(extra)
if extra:
clone._extra.update(extra)
if kwargs:
clone._extra.update(kwargs)
return clone
def __getitem__(self, slice):
@@ -249,22 +252,27 @@ class SQLQuery(object):
class SQLCompiler(object):
def get_columns(self,):
if self._values:
return ", ".join(self._values)
return "*"
extra_columns = self.get_extra_columns()
columns = ", ".join(self._values)
return ", ".join([item for item in [columns, extra_columns] if item])
def get_extra_columns(self,):
select = self._extra.get("select", None)
if select:
return ", " + select
return self._extra.get("select", "")
def get_extra_where(self,):
where = self._extra.get("where", [])
if where:
return " AND ".join(where)
def get_table(self,):
return self._table
def get_where(self):
filters = self._filters & ~self._excludes
if filters:
return "WHERE " + unicode(filters)
filters = unicode(self._filters & ~self._excludes)
extra_where = self.get_extra_where()
if filters or extra_where:
return "WHERE " + " ".join([item for item in [filters, extra_where] if item])
def get_order_by(self,):
conds = []
@@ -307,10 +315,11 @@ class SQLCompiler(object):
return "TOP {0}".format(self._limits.stop)
def _compile(self):
sql_all = [
"SELECT", self.get_top(), self.get_columns(), self.get_extra_columns(),
"FROM", self.get_table(), self.get_joins(), self.get_where(),
self.get_group_by(), self.get_order_by(), self.get_limits()]
sql_all = ["SELECT", self.get_top(), self.get_columns(),
"FROM", self.get_table(),
self.get_joins(), self.get_where(),
self.get_group_by(), self.get_order_by(),
self.get_limits()]
return " ".join([item for item in sql_all if item])

View File

@@ -66,9 +66,21 @@ class TestSqlBuilder(unittest.TestCase):
self.assertEqual(
str(sql), "SELECT name, date, tlf FROM users WHERE ((name='jhon') AND NOT (DATEPART('year', date)<=1977))")
def test_extra(self):
sql = Queryset("users").values("name", "date", "tlf")
sql = sql.extra({'select': 'count(*) as total'})
self.assertEqual(
str(sql), "SELECT name, date, tlf , count(*) as total FROM users WHERE ((name='jhon') AND NOT (DATEPART('year', date)<=1977))")
str(sql), "SELECT name, date, tlf, count(*) as total FROM users")
sql = Queryset("users")
sql = sql.extra(where=["id=1", "name='jose'"])
sql = sql.extra(select="count(*) as total")
self.assertEqual(
str(sql), "SELECT *, count(*) as total FROM users WHERE id=1 AND name='jose'")
sql = sql.values(*[])
self.assertEqual(
str(sql), "SELECT count(*) as total FROM users WHERE id=1 AND name='jose'")
def test_in(self,):
sql = Queryset("users")