Mejoramos los test

This commit is contained in:
2014-11-25 00:12:03 +01:00
parent b595f7f124
commit be92a456e7
2 changed files with 57 additions and 32 deletions

View File

@@ -66,6 +66,22 @@ class Operator(QMixin):
__nonzero__ = __bool__ __nonzero__ = __bool__
class F(object):
def __init__(self, value):
self.value = value
def __repr__(self,):
return "%s" % self.value
__str__ = __repr__
def __bool__(self):
return bool(self.value)
__nonzero__ = __bool__
class Q(QMixin): class Q(QMixin):
lookup_types = [ lookup_types = [
'iexact', 'contains', 'icontains', 'iexact', 'contains', 'icontains',
@@ -101,6 +117,9 @@ class Q(QMixin):
if isinstance(value, datetime.date): if isinstance(value, datetime.date):
return "'%s'" % value.strftime("%Y-%m-%d") return "'%s'" % value.strftime("%Y-%m-%d")
if isinstance(value, F):
return value
return "'%s'" % value return "'%s'" % value
def _process(self, compose_column, value): def _process(self, compose_column, value):
@@ -152,9 +171,10 @@ class Q(QMixin):
class SQLQuery(object): class SQLQuery(object):
def __init__(self, table=None): def __init__(self, table=None, sql_mode="MYSQL"):
if table: if table:
self._table = table self._table = table
self.sql_mode = sql_mode
self._values = [] self._values = []
self._order_by = [] self._order_by = []
self._group_by = [] self._group_by = []
@@ -218,7 +238,6 @@ class SQLCompiler(object):
select = self._extra.get("select", None) select = self._extra.get("select", None)
if select: if select:
return ", " + select return ", " + select
return ""
def get_table(self,): def get_table(self,):
return self._table return self._table
@@ -228,8 +247,6 @@ class SQLCompiler(object):
if filters: if filters:
return "WHERE " + str(filters) return "WHERE " + str(filters)
return ""
def get_order_by(self,): def get_order_by(self,):
conds = [] conds = []
for cond in self._order_by: for cond in self._order_by:
@@ -237,29 +254,26 @@ class SQLCompiler(object):
column = cond column = cond
try: try:
if cond[0] == "-": if cond[0] == "-":
order = "desc" order = " DESC"
column = cond[1:] column = cond[1:]
except: except:
pass pass
conds.append("{0} {1}".format(column, order)) conds.append("{0}{1}".format(column, order))
if conds: if conds:
return "ORDER BY " + ", ".join(conds) return "ORDER BY " + ", ".join(conds)
return ""
def get_group_by(self,): def get_group_by(self,):
if self._group_by: if self._group_by:
return "GROUP BY " + ", ".join(self._group_by) return "GROUP BY " + ", ".join(self._group_by)
return ""
def get_joins(self,): def get_joins(self,):
if self._joins: if self._joins:
return " ".join(self._joins) return " ".join(self._joins)
return ""
def get_limits(self,): def get_limits(self,):
if self._limits: if self._limits and self.sql_mode != "SQL_SERVER":
offset = self._limits.start offset = self._limits.start
limit = self._limits.stop limit = self._limits.stop
if offset: if offset:
@@ -268,28 +282,18 @@ class SQLCompiler(object):
if offset: if offset:
str += " OFFSET {0}".format(offset) str += " OFFSET {0}".format(offset)
return str return str
return ""
def get_top(self,):
if self._limits and self.sql_mode == "SQL_SERVER":
return "TOP {0}".format(self._limits.stop)
def _compile(self): def _compile(self):
sql = """ sql_all = [
SELECT {columns}{extra_columns} "SELECT", self.get_top(), self.get_columns(), self.get_extra_columns(),
FROM {table} "FROM", self.get_table(), self.get_joins(), self.get_where(),
{joins} self.get_group_by(), self.get_order_by(), self.get_limits()]
{where}
{group_by} return " ".join([item for item in sql_all if item])
{order_by}
{limits}
""".format(
columns=self.get_columns(),
extra_columns=self.get_extra_columns(),
table=self.get_table(),
joins=self.get_joins(),
where=self.get_where(),
group_by=self.get_group_by(),
order_by=self.get_order_by(),
limits=self.get_limits()
)
return sql
def __repr__(self): def __repr__(self):
return self._compile() return self._compile()
@@ -305,4 +309,4 @@ class SQLModel(object):
@classproperty @classproperty
def objects(cls): def objects(cls):
return Queryset(cls.table) return Queryset(cls.table, getattr(cls, 'sql_mode', None))

View File

@@ -1,6 +1,6 @@
import unittest import unittest
import datetime import datetime
from . import Q from . import Q, Queryset, F
class TestSqlBuilder(unittest.TestCase): class TestSqlBuilder(unittest.TestCase):
@@ -22,6 +22,27 @@ class TestSqlBuilder(unittest.TestCase):
self.assertEqual(str(Q(fecha__year__lte=2012)), "(DATEPART('year', fecha)<=2012)") self.assertEqual(str(Q(fecha__year__lte=2012)), "(DATEPART('year', fecha)<=2012)")
self.assertEqual(str(Q(fecha__year=2012)), "(DATEPART('year', fecha)=2012)") self.assertEqual(str(Q(fecha__year=2012)), "(DATEPART('year', fecha)=2012)")
def test_limits(self):
self.assertEqual(Queryset("table")[:10].get_limits(), "LIMIT 10")
self.assertEqual(Queryset("table")[1:10].get_limits(), "LIMIT 9 OFFSET 1")
def test_compound(self):
qs = Queryset("users", "SQL_SERVER")\
.filter(nombre="jose")\
.order_by( "nombre", "-fecha")\
.filter(fecha__lte=F("now()"))[:10]
self.assertEqual(
str(qs), "SELECT TOP 10 * FROM users WHERE ((nombre='jose') AND (fecha<=now())) ORDER BY nombre, fecha DESC")
qs = Queryset("users")\
.filter(nombre="jose")\
.order_by( "nombre", "-fecha")\
.filter(fecha__lte=F("now()"))[:10]
self.assertEqual(
str(qs), "SELECT * FROM users WHERE ((nombre='jose') AND (fecha<=now())) ORDER BY nombre, fecha DESC LIMIT 10")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()