Project

General

Profile

« Previous | Next » 

Revision 2101

sql.py: DbConn: Fixed bug where schemas db_config value needed to be split apart into strings. Fixed bug where current_setting() returned a value rather than an identifier, so it had to be used with set_config() instead of SET, and run after SET TRANSACTION ISOLATION LEVEL. Moved Input validation section before Database connections because it's used by Database connections.

View differences:

sql.py
71 71
    try: return value(cur)
72 72
    except StopIteration: return None
73 73

  
74
##### Input validation
75

  
76
def clean_name(name): return re.sub(r'\W', r'', name)
77

  
78
def check_name(name):
79
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
80
        +'" may contain only alphanumeric characters and _')
81

  
82
def esc_name_by_module(module, name, ignore_case=False):
83
    if module == 'psycopg2':
84
        if ignore_case:
85
            # Don't enclose in quotes because this disables case-insensitivity
86
            check_name(name)
87
            return name
88
        else: quote = '"'
89
    elif module == 'MySQLdb': quote = '`'
90
    else: raise NotImplementedError("Can't escape name for "+module+' database')
91
    return quote + name.replace(quote, '') + quote
92

  
93
def esc_name_by_engine(engine, name, **kw_args):
94
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
95

  
96
def esc_name(db, name, **kw_args):
97
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
98

  
99
def qual_name(db, schema, table):
100
    def esc_name_(name): return esc_name(db, name)
101
    table = esc_name_(table)
102
    if schema != None: return esc_name_(schema)+'.'+table
103
    else: return table
104

  
74 105
##### Database connections
75 106

  
76 107
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
......
133 164
            self.__db = module.connect(**db_config)
134 165
            
135 166
            # Configure connection
136
            if schemas != None:
137
                schemas = schemas[:] # don't modify input!
138
                schemas.append('current_setting(search_path)')
139
                run_raw_query(self, 'SET search_path = '+(', '.join(schemas)))
140 167
            if self.serializable: run_raw_query(self,
141 168
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
169
            if schemas != None:
170
                schemas_ = ''.join((esc_name(self, s)+', '
171
                    for s in schemas.split(',')))
172
                run_raw_query(self, "SELECT set_config('search_path', \
173
%s || current_setting('search_path'), false)", [schemas_])
142 174
        
143 175
        return self.__db
144 176
    
......
225 257

  
226 258
connect = DbConn
227 259

  
228
##### Input validation
229

  
230
def clean_name(name): return re.sub(r'\W', r'', name)
231

  
232
def check_name(name):
233
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
234
        +'" may contain only alphanumeric characters and _')
235

  
236
def esc_name_by_module(module, name, ignore_case=False):
237
    if module == 'psycopg2':
238
        if ignore_case:
239
            # Don't enclose in quotes because this disables case-insensitivity
240
            check_name(name)
241
            return name
242
        else: quote = '"'
243
    elif module == 'MySQLdb': quote = '`'
244
    else: raise NotImplementedError("Can't escape name for "+module+' database')
245
    return quote + name.replace(quote, '') + quote
246

  
247
def esc_name_by_engine(engine, name, **kw_args):
248
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
249

  
250
def esc_name(db, name, **kw_args):
251
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
252

  
253
def qual_name(db, schema, table):
254
    def esc_name_(name): return esc_name(db, name)
255
    table = esc_name_(table)
256
    if schema != None: return esc_name_(schema)+'.'+table
257
    else: return table
258

  
259 260
##### Querying
260 261

  
261 262
def run_raw_query(db, *args, **kw_args):

Also available in: Unified diff