From 68229e343cff442b451488a1a89ac4578d135048 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20S=C3=A1nchez=20Moreno?= Date: Thu, 9 Jul 2015 13:07:56 +0200 Subject: [PATCH] =?UTF-8?q?Versi=C3=B3n=200.11?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlquerybuilder/__init__.py | 123 ++++++++++++++++++++++++++++-------- 1 file changed, 96 insertions(+), 27 deletions(-) diff --git a/sqlquerybuilder/__init__.py b/sqlquerybuilder/__init__.py index 9944e78..46183ae 100644 --- a/sqlquerybuilder/__init__.py +++ b/sqlquerybuilder/__init__.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals import datetime import copy -VERSION = "0.0.9" +VERSION = "0.0.11" class classproperty(object): @@ -18,11 +18,14 @@ class QMixin(object): AND = 'AND' OR = 'OR' NOT = 'NOT' + UNION = 'UNION' def _combine(self, other, conn): return Operator(conn, self, other) def __or__(self, other): + if isinstance(other, Queryset): + return self._combine(other, self.UNION) return self._combine(other, self.OR) def __and__(self, other): @@ -76,6 +79,8 @@ class F(object): class Q(QMixin): + _mode = "MYSQL" + lookup_types = [ 'icontains', 'istartswith', 'iendswith', 'contains', 'startswith', 'endswith', @@ -91,6 +96,8 @@ class Q(QMixin): def __init__(self, *args, **kwargs): self.conditions = kwargs + for arg in args: + self.conditions[arg] = None def __repr__(self,): return self._compile() @@ -100,15 +107,27 @@ class Q(QMixin): __nonzero__ = __bool__ + @property + def date_format(self): + if self._mode == 'SQLSERVER': + return "%Y-%d-%m" + return "%Y-%m-%d" + + @property + def datetime_format(self): + if self._mode == 'SQLSERVER': + return "%Y-%d-%m %H:%M:%S" + return "%Y-%m-%d %H:%M:%S" + def _get_value(self, value): if isinstance(value, int) or isinstance(value, float): return unicode(value) if isinstance(value, datetime.datetime): - return "'%s'" % value.strftime("%Y-%m-%d %H:%M:%S") + return "'%s'" % value.strftime(self.datetime_format) if isinstance(value, datetime.date): - return "'%s'" % value.strftime("%Y-%m-%d") + return "'%s'" % value.strftime(self.date_format) if isinstance(value, list) or isinstance(value, set): return ", ".join([self._get_value(item) for item in value]) @@ -121,6 +140,9 @@ class Q(QMixin): def _process(self, compose_column, value): arr = compose_column.split("__") column = arr.pop(0) + if column == '': + column += "__" + arr.pop(0) + try: lookup = arr.pop(0) except: @@ -164,7 +186,10 @@ class Q(QMixin): if lookup in self.op_map.keys(): return "{0}{1}{2}".format(column, self.op_map[lookup], self._get_value(value)) - return "{0}{1}{2}".format(column, "=", self._get_value(value)) + if value is not None: + return "{0}{1}{2}".format(column, "=", self._get_value(value)) + + return column def _compile(self,): filters = [] @@ -179,9 +204,9 @@ class Q(QMixin): class SQLQuery(object): - def __init__(self, table=None, sql_mode="MYSQL"): - if table: - self._table = table + def __init__(self, table=None, sql_mode="MYSQL", sql=None, **kwargs): + self.kwargs = kwargs + self._table = table self.sql_mode = sql_mode self._values = ["*"] self._order_by = [] @@ -191,6 +216,24 @@ class SQLQuery(object): self._excludes = Q() self._extra = {} self._limits = None + self._sql = sql + + def has_filters(self,): + return self._order_by or self._group_by or self._joins\ + or self._filters or self._excludes or self._extra \ + or self._limits or self._values != ['*'] + + def _q(self, *args, **kwargs): + conds = Q() + conds._mode = self.sql_mode + for arg in args: + if issubclass(arg.__class__, QMixin): + arg._mode = self.sql_mode + conds &= arg + + _conds = Q(**kwargs) + _conds._mode = self.sql_mode + return conds & _conds def _clone(self,): return copy.deepcopy(self) @@ -202,20 +245,12 @@ class SQLQuery(object): def filter(self, *args, **kwargs): clone = self._clone() - for arg in args: - if issubclass(arg.__class__, QMixin): - clone._filters &= arg - - clone._filters &= Q(**kwargs) + clone._filters &= self._q(*args, **kwargs) return clone def exclude(self, *args, **kwargs): clone = self._clone() - for arg in args: - if issubclass(arg.__class__, QMixin): - clone._excludes &= arg - - clone._excludes &= Q(**kwargs) + clone._excludes &= self._q(*args, **kwargs) return clone def order_by(self, *args): @@ -231,7 +266,7 @@ class SQLQuery(object): def join(self, table, on="", how="inner join"): clone = self._clone() if on: - on = "ON " + on + on = "ON " + on.format(table=self._table) clone._joins.append("{how} {table} {on}".format(how=how, table=table, on=on)) return clone @@ -272,12 +307,15 @@ class SQLCompiler(object): extra_where = self.get_extra_where() if filters or extra_where: - return "WHERE " + " ".join([item for item in [filters, extra_where] if item]) + return "WHERE " + " AND ".join([item for item in [filters, extra_where] if item]) def get_order_by(self,): conds = [] for cond in self._order_by: order = "" + if cond is None: + continue + column = cond try: if cond[0] == "-": @@ -311,17 +349,45 @@ class SQLCompiler(object): return str def get_top(self,): - if self._limits and self.sql_mode == "SQL_SERVER": + if self._limits and self.sql_mode == "SQL_SERVER" and not self._limits.start: return "TOP {0}".format(self._limits.stop) - def _compile(self): - sql_all = ["SELECT", self.get_top(), self.get_columns(), - "FROM", self.get_table(), - self.get_joins(), self.get_where(), - self.get_group_by(), self.get_order_by(), - self.get_limits()] + def get_sql_structure(self): + if self._sql: + if not self.has_filters(): + return [self._sql] + table = "(%s) as union1" % self._sql + else: + table = self.get_table() - return " ".join([item for item in sql_all if item]) + sql = ["SELECT", self.get_top(), self.get_columns(), + "FROM", table, "WITH (NOLOCK)", + self.get_joins(), self.get_where(), + self.get_group_by(), self.get_order_by(), + self.get_limits()] + + if self.sql_mode == "SQL_SERVER" and self._limits and \ + self._limits.start is not None and self._limits.stop is not None: + conds = [] + if self._limits.start is not None: + conds.append("row_number > %s" % self._limits.start) + + if self._limits.stop is not None: + conds.append("row_number <= %s" % self._limits.stop) + + conds = " AND ".join(conds) + paginate = "ROW_NUMBER() OVER (%s) as row_number" % self.get_order_by() + + return ["SELECT * FROM (", "SELECT", ",".join([paginate, self.get_columns()]), + "FROM", table, + self.get_joins(), self.get_where(), + self.get_group_by(), + self.get_limits(), ") as tbl_paginated WHERE ", conds] + + return sql + + def _compile(self): + return " ".join([unicode(item) for item in self.get_sql_structure() if item]) def __repr__(self): return self._compile() @@ -332,6 +398,9 @@ class SQLCompiler(object): def sql(self,): return self.__str__() + def __or__(self, other): + return self.__class__(sql="%s UNION %s" % (self, other)) + class Queryset(SQLCompiler, SQLQuery): pass