ddl_references.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. """
  2. Helpers to manipulate deferred DDL statements that might need to be adjusted or
  3. discarded within when executing a migration.
  4. """
  5. from copy import deepcopy
  6. class Reference:
  7. """Base class that defines the reference interface."""
  8. def references_table(self, table):
  9. """
  10. Return whether or not this instance references the specified table.
  11. """
  12. return False
  13. def references_column(self, table, column):
  14. """
  15. Return whether or not this instance references the specified column.
  16. """
  17. return False
  18. def references_index(self, table, index):
  19. """
  20. Return whether or not this instance references the specified index.
  21. """
  22. return False
  23. def rename_table_references(self, old_table, new_table):
  24. """
  25. Rename all references to the old_name to the new_table.
  26. """
  27. pass
  28. def rename_column_references(self, table, old_column, new_column):
  29. """
  30. Rename all references to the old_column to the new_column.
  31. """
  32. pass
  33. def __repr__(self):
  34. return "<%s %r>" % (self.__class__.__name__, str(self))
  35. def __str__(self):
  36. raise NotImplementedError(
  37. "Subclasses must define how they should be converted to string."
  38. )
  39. class Table(Reference):
  40. """Hold a reference to a table."""
  41. def __init__(self, table, quote_name):
  42. self.table = table
  43. self.quote_name = quote_name
  44. def references_table(self, table):
  45. return self.table == table
  46. def references_index(self, table, index):
  47. return self.references_table(table) and str(self) == index
  48. def rename_table_references(self, old_table, new_table):
  49. if self.table == old_table:
  50. self.table = new_table
  51. def __str__(self):
  52. return self.quote_name(self.table)
  53. class TableColumns(Table):
  54. """Base class for references to multiple columns of a table."""
  55. def __init__(self, table, columns):
  56. self.table = table
  57. self.columns = columns
  58. def references_column(self, table, column):
  59. return self.table == table and column in self.columns
  60. def rename_column_references(self, table, old_column, new_column):
  61. if self.table == table:
  62. for index, column in enumerate(self.columns):
  63. if column == old_column:
  64. self.columns[index] = new_column
  65. class Columns(TableColumns):
  66. """Hold a reference to one or many columns."""
  67. def __init__(self, table, columns, quote_name, col_suffixes=()):
  68. self.quote_name = quote_name
  69. self.col_suffixes = col_suffixes
  70. super().__init__(table, columns)
  71. def __str__(self):
  72. def col_str(column, idx):
  73. col = self.quote_name(column)
  74. try:
  75. suffix = self.col_suffixes[idx]
  76. if suffix:
  77. col = "{} {}".format(col, suffix)
  78. except IndexError:
  79. pass
  80. return col
  81. return ", ".join(
  82. col_str(column, idx) for idx, column in enumerate(self.columns)
  83. )
  84. class IndexName(TableColumns):
  85. """Hold a reference to an index name."""
  86. def __init__(self, table, columns, suffix, create_index_name):
  87. self.suffix = suffix
  88. self.create_index_name = create_index_name
  89. super().__init__(table, columns)
  90. def __str__(self):
  91. return self.create_index_name(self.table, self.columns, self.suffix)
  92. class IndexColumns(Columns):
  93. def __init__(self, table, columns, quote_name, col_suffixes=(), opclasses=()):
  94. self.opclasses = opclasses
  95. super().__init__(table, columns, quote_name, col_suffixes)
  96. def __str__(self):
  97. def col_str(column, idx):
  98. # Index.__init__() guarantees that self.opclasses is the same
  99. # length as self.columns.
  100. col = "{} {}".format(self.quote_name(column), self.opclasses[idx])
  101. try:
  102. suffix = self.col_suffixes[idx]
  103. if suffix:
  104. col = "{} {}".format(col, suffix)
  105. except IndexError:
  106. pass
  107. return col
  108. return ", ".join(
  109. col_str(column, idx) for idx, column in enumerate(self.columns)
  110. )
  111. class ForeignKeyName(TableColumns):
  112. """Hold a reference to a foreign key name."""
  113. def __init__(
  114. self,
  115. from_table,
  116. from_columns,
  117. to_table,
  118. to_columns,
  119. suffix_template,
  120. create_fk_name,
  121. ):
  122. self.to_reference = TableColumns(to_table, to_columns)
  123. self.suffix_template = suffix_template
  124. self.create_fk_name = create_fk_name
  125. super().__init__(
  126. from_table,
  127. from_columns,
  128. )
  129. def references_table(self, table):
  130. return super().references_table(table) or self.to_reference.references_table(
  131. table
  132. )
  133. def references_column(self, table, column):
  134. return super().references_column(
  135. table, column
  136. ) or self.to_reference.references_column(table, column)
  137. def rename_table_references(self, old_table, new_table):
  138. super().rename_table_references(old_table, new_table)
  139. self.to_reference.rename_table_references(old_table, new_table)
  140. def rename_column_references(self, table, old_column, new_column):
  141. super().rename_column_references(table, old_column, new_column)
  142. self.to_reference.rename_column_references(table, old_column, new_column)
  143. def __str__(self):
  144. suffix = self.suffix_template % {
  145. "to_table": self.to_reference.table,
  146. "to_column": self.to_reference.columns[0],
  147. }
  148. return self.create_fk_name(self.table, self.columns, suffix)
  149. class Statement(Reference):
  150. """
  151. Statement template and formatting parameters container.
  152. Allows keeping a reference to a statement without interpolating identifiers
  153. that might have to be adjusted if they're referencing a table or column
  154. that is removed
  155. """
  156. def __init__(self, template, **parts):
  157. self.template = template
  158. self.parts = parts
  159. def references_table(self, table):
  160. return any(
  161. hasattr(part, "references_table") and part.references_table(table)
  162. for part in self.parts.values()
  163. )
  164. def references_column(self, table, column):
  165. return any(
  166. hasattr(part, "references_column") and part.references_column(table, column)
  167. for part in self.parts.values()
  168. )
  169. def references_index(self, table, index):
  170. return any(
  171. hasattr(part, "references_index") and part.references_index(table, index)
  172. for part in self.parts.values()
  173. )
  174. def rename_table_references(self, old_table, new_table):
  175. for part in self.parts.values():
  176. if hasattr(part, "rename_table_references"):
  177. part.rename_table_references(old_table, new_table)
  178. def rename_column_references(self, table, old_column, new_column):
  179. for part in self.parts.values():
  180. if hasattr(part, "rename_column_references"):
  181. part.rename_column_references(table, old_column, new_column)
  182. def __str__(self):
  183. return self.template % self.parts
class Expressions(TableColumns):
    """
    Hold a reference to a tree of query expressions on a table.

    The referenced columns are not passed in. They are derived from the
    expression tree via ``compiler.query._gen_cols()``. The tree is compiled
    to SQL only when this reference is converted to a string.
    """

    def __init__(self, table, expressions, compiler, quote_value):
        # compiler: presumably a SQL compiler whose ``query`` can walk the
        # expression tree and which can ``compile()`` it -- confirm at call
        # sites. quote_value: callable that __str__ uses to turn each SQL
        # parameter into an inline literal.
        self.compiler = compiler
        self.expressions = expressions
        self.quote_value = quote_value
        # Collect the column name of every column reference in the tree.
        # These back the inherited references_column().
        columns = [
            col.target.column
            for col in self.compiler.query._gen_cols([self.expressions])
        ]
        super().__init__(table, columns)

    def rename_table_references(self, old_table, new_table):
        # Nothing to do unless this reference belongs to old_table.
        if self.table != old_table:
            return
        # The result of relabeled_clone() is reassigned, so it is expected to
        # return a relabeled tree. The super() call then updates self.table.
        self.expressions = self.expressions.relabeled_clone({old_table: new_table})
        super().rename_table_references(old_table, new_table)

    def rename_column_references(self, table, old_column, new_column):
        # Nothing to do unless this reference belongs to ``table``.
        if self.table != table:
            return
        # Column targets are modified in place below, so work on a deep copy.
        # Presumably this keeps objects reachable from the current
        # self.expressions (e.g. model fields) from being renamed; it depends
        # on how those objects implement deepcopy -- confirm.
        expressions = deepcopy(self.expressions)
        # self.columns is rebuilt from the copied tree, in traversal order.
        self.columns = []
        for col in self.compiler.query._gen_cols([expressions]):
            if col.target.column == old_column:
                col.target.column = new_column
            self.columns.append(col.target.column)
        # self.expressions is replaced only after the loop completes.
        self.expressions = expressions

    def __str__(self):
        # Compile to (sql, params), quote each param via quote_value, and
        # inline them with %-formatting. Assumes the compiled SQL uses
        # %-style placeholders such as %s -- confirm against the compiler.
        sql, params = self.compiler.compile(self.expressions)
        params = map(self.quote_value, params)
        return sql % tuple(params)