diff --git a/__init__.py b/__init__.py index 78f9b5e..be1d1ca 100644 --- a/__init__.py +++ b/__init__.py @@ -66,6 +66,22 @@ class Operator(QMixin): __nonzero__ = __bool__ +class F(object): + + def __init__(self, value): + self.value = value + + def __repr__(self,): + return "%s" % self.value + + __str__ = __repr__ + + def __bool__(self): + return bool(self.value) + + __nonzero__ = __bool__ + + class Q(QMixin): lookup_types = [ 'iexact', 'contains', 'icontains', @@ -101,6 +117,9 @@ class Q(QMixin): if isinstance(value, datetime.date): return "'%s'" % value.strftime("%Y-%m-%d") + if isinstance(value, F): + return value + return "'%s'" % value def _process(self, compose_column, value): @@ -152,9 +171,10 @@ class Q(QMixin): class SQLQuery(object): - def __init__(self, table=None): + def __init__(self, table=None, sql_mode="MYSQL"): if table: self._table = table + self.sql_mode = sql_mode self._values = [] self._order_by = [] self._group_by = [] @@ -218,7 +238,6 @@ class SQLCompiler(object): select = self._extra.get("select", None) if select: return ", " + select - return "" def get_table(self,): return self._table @@ -228,8 +247,6 @@ class SQLCompiler(object): if filters: return "WHERE " + str(filters) - return "" - def get_order_by(self,): conds = [] for cond in self._order_by: @@ -237,29 +254,26 @@ class SQLCompiler(object): column = cond try: if cond[0] == "-": - order = "desc" + order = " DESC" column = cond[1:] except: pass - conds.append("{0} {1}".format(column, order)) + conds.append("{0}{1}".format(column, order)) if conds: return "ORDER BY " + ", ".join(conds) - return "" def get_group_by(self,): if self._group_by: return "GROUP BY " + ", ".join(self._group_by) - return "" def get_joins(self,): if self._joins: return " ".join(self._joins) - return "" def get_limits(self,): - if self._limits: + if self._limits and self.sql_mode != "SQL_SERVER": offset = self._limits.start limit = self._limits.stop if offset: @@ -268,28 +282,18 @@ class SQLCompiler(object): if offset: str += " OFFSET {0}".format(offset) return str - return "" + + def get_top(self,): + if self._limits and self.sql_mode == "SQL_SERVER": + return "TOP {0}".format(self._limits.stop) def _compile(self): - sql = """ - SELECT {columns}{extra_columns} - FROM {table} - {joins} - {where} - {group_by} - {order_by} - {limits} - """.format( - columns=self.get_columns(), - extra_columns=self.get_extra_columns(), - table=self.get_table(), - joins=self.get_joins(), - where=self.get_where(), - group_by=self.get_group_by(), - order_by=self.get_order_by(), - limits=self.get_limits() - ) - return sql + sql_all = [ + "SELECT", self.get_top(), self.get_columns(), self.get_extra_columns(), + "FROM", self.get_table(), self.get_joins(), self.get_where(), + self.get_group_by(), self.get_order_by(), self.get_limits()] + + return " ".join([item for item in sql_all if item]) def __repr__(self): return self._compile() @@ -305,4 +309,4 @@ class SQLModel(object): @classproperty def objects(cls): - return Queryset(cls.table) + return Queryset(cls.table, getattr(cls, 'sql_mode', None)) diff --git a/tests.py b/tests.py index d0870e7..a5ebd91 100644 --- a/tests.py +++ b/tests.py @@ -1,6 +1,6 @@ import unittest import datetime -from . import Q +from . import Q, Queryset, F class TestSqlBuilder(unittest.TestCase): @@ -22,6 +22,27 @@ class TestSqlBuilder(unittest.TestCase): self.assertEqual(str(Q(fecha__year__lte=2012)), "(DATEPART('year', fecha)<=2012)") self.assertEqual(str(Q(fecha__year=2012)), "(DATEPART('year', fecha)=2012)") + def test_limits(self): + self.assertEqual(Queryset("table")[:10].get_limits(), "LIMIT 10") + self.assertEqual(Queryset("table")[1:10].get_limits(), "LIMIT 9 OFFSET 1") + + def test_compound(self): + qs = Queryset("users", "SQL_SERVER")\ + .filter(nombre="jose")\ + .order_by( "nombre", "-fecha")\ + .filter(fecha__lte=F("now()"))[:10] + + self.assertEqual( + str(qs), "SELECT TOP 10 * FROM users WHERE ((nombre='jose') AND (fecha<=now())) ORDER BY nombre, fecha DESC") + + qs = Queryset("users")\ + .filter(nombre="jose")\ + .order_by( "nombre", "-fecha")\ + .filter(fecha__lte=F("now()"))[:10] + + self.assertEqual( + str(qs), "SELECT * FROM users WHERE ((nombre='jose') AND (fecha<=now())) ORDER BY nombre, fecha DESC LIMIT 10") + if __name__ == '__main__': unittest.main()