# lib/baza/sql_queries/postgres_upsert_duplicate_key.rb
# Performs an "upsert" against PostgreSQL: inserts a row, or updates the
# existing row when the insert hits a unique-key conflict.
#
# Arguments (read from the +args+ hash passed to +new+):
# * +:db+         - database handle. This class calls +query+, +single+,
#                   +last_id+, +quote_table+, +quote_column+, +quote_value+,
#                   +tables+ and +commands+ on it.
# * +:table_name+ - table to upsert into.
# * +:updates+    - column => value pairs written on insert and assigned when
#                   the row already exists.
# * +:terms+      - column => value pairs identifying the row (expected to be a
#                   unique key). They are included in the INSERT and form the
#                   ON CONFLICT target / the fallback UPDATE's WHERE clause.
#                   When empty, the conflicting key is taken from the
#                   duplicate-key error message instead.
# * +:return_id+  - when truthy, +execute+ returns the upserted row's primary key.
class Baza::SqlQueries::PostgresUpsertDuplicateKey
  # Matches the detail of a duplicate-key error, e.g.
  # "Key (email)=(a@example.com) already exists". The column list is matched
  # non-greedily so a ")=(" inside the value is not taken as the separator.
  DUPLICATE_KEY_REGEX = /Key \((.+?)\)=\((.*)\) already exists/

  def initialize(args)
    @db = args.fetch(:db)
    @table_name = args.fetch(:table_name)
    @updates = StringCases.stringify_keys(args.fetch(:updates))
    @terms = StringCases.stringify_keys(args.fetch(:terms))
    @return_id = args[:return_id]
  end

  # Runs the upsert.
  #
  # Returns the primary key of the inserted/updated row when +:return_id+ was
  # given, otherwise nil.
  def execute
    # Without terms there is no conflict target to name, so the conflicting
    # key has to be discovered from the duplicate-key error instead.
    return insert_and_register_conflict if @terms.empty?

    if supports_on_conflict?
      @db.query(on_conflict_sql)
    else
      @db.query(begin_update_exception_sql)
    end

    # The row may have been updated rather than inserted, in which case an
    # insert-based last_id would not point at it; look it up by its terms.
    primary_key_for(@terms) if @return_id
  end

  private

  # INSERT ... ON CONFLICT is available from PostgreSQL 9.5 onwards.
  def supports_on_conflict?
    @db.commands.version.to_f >= 9.5
  end

  # Attempts a plain INSERT. If it fails with a duplicate-key error, the key
  # columns/values reported by the error become the terms and the upsert is
  # run through begin_update_exception_sql. Any other error is re-raised.
  #
  # Returns the primary key when +:return_id+ was given, otherwise nil.
  def insert_and_register_conflict
    @db.query(insert_sql)
    @db.last_id if @return_id
  rescue => e
    conflict_terms = parse_conflict_terms(e.message)
    raise unless conflict_terms # re-raises e untouched

    @terms = conflict_terms
    @db.query(begin_update_exception_sql)
    primary_key_for(@terms) if @return_id
  end

  # Extracts column => value pairs from a duplicate-key message such as
  # "Key (a, b)=(1, 2) already exists".
  #
  # Returns nil when the message is not a duplicate-key error, or when a
  # multi-column key cannot be split unambiguously (e.g. a value contains ", ").
  def parse_conflict_terms(message)
    match = DUPLICATE_KEY_REGEX.match(message)
    return unless match

    columns = match[1].split(", ")
    # A single-column key keeps its value verbatim, even if it contains ", ".
    values = columns.length == 1 ? [match[2]] : match[2].split(", ")
    return unless columns.length == values.length

    columns.zip(values).to_h
  end

  # Looks up the row matching +terms+ and returns its primary key as an Integer.
  def primary_key_for(terms)
    primary_column = table.columns.find(&:primarykey?).name.to_sym
    data = @db.single(@table_name, terms)
    data.fetch(primary_column).to_i
  end

  # Pre-9.5 upsert: an anonymous plpgsql block that tries the INSERT and runs
  # the UPDATE when the INSERT raises unique_violation.
  def begin_update_exception_sql
    # With nothing to update, an existing row is simply left as it is.
    on_duplicate = @updates.empty? ? "NULL" : update_sql
    body = [
      "BEGIN",
      "\t#{insert_sql};",
      "EXCEPTION WHEN unique_violation THEN",
      "\t#{on_duplicate};",
      "END"
    ].join("\n")
    delimiter = dollar_quote_delimiter_for(body)

    "do #{delimiter}\n#{body} #{delimiter};"
  end

  # Returns a dollar-quote delimiter that does not occur in +body+, so quoted
  # values containing "$$" cannot terminate the DO block early. Plain "$$" is
  # used whenever it is safe.
  def dollar_quote_delimiter_for(body)
    delimiter = "$$"
    counter = 0
    while body.include?(delimiter)
      counter += 1
      delimiter = "$upsert#{counter}$"
    end
    delimiter
  end

  # INSERT ... ON CONFLICT (<terms>) DO UPDATE SET <updates>
  # (or DO NOTHING when there is nothing to update).
  def on_conflict_sql
    action = @updates.empty? ? "DO NOTHING" : "DO UPDATE #{update_set_sql}"
    "#{insert_sql} ON CONFLICT (#{conflict_column_sql}) #{action}"
  end

  # Conflict target: the identifying term columns, i.e. the same columns the
  # fallback UPDATE filters on.
  def conflict_column_sql
    @terms.keys.map { |column_name| @db.quote_column(column_name) }.join(", ")
  end

  # INSERT of updates and terms combined; term values win on shared columns.
  def insert_sql
    data = @updates.merge(@terms)
    columns = data.keys.map { |column_name| @db.quote_column(column_name) }.join(", ")
    values = data.values.map { |value| @db.quote_value(value).to_s }.join(", ")

    "INSERT INTO #{@db.quote_table(@table_name)} (#{columns}) VALUES (#{values})"
  end

  def update_sql
    "UPDATE #{@db.quote_table(@table_name)} #{update_set_sql} #{update_where_sql}"
  end

  def update_set_sql
    "SET #{@updates.map { |key, value| column_equals_sql(key, value) }.join(', ')}"
  end

  def update_where_sql
    "WHERE #{@terms.map { |key, value| column_equals_sql(key, value) }.join(' AND ')}"
  end

  # "<quoted column> = <quoted value>", shared by SET and WHERE clauses.
  def column_equals_sql(column_name, value)
    "#{@db.quote_column(column_name)} = #{@db.quote_value(value)}"
  end

  def table
    @table ||= @db.tables[@table_name.to_s]
  end
end