Skip to content

Commit f7d1699

Browse files
extermmatzbot
authored andcommitted
[ruby/prism] Implement case equality on nodes
ruby/prism@dc121e4fdf
1 parent 87b829a commit f7d1699

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

prism/templates/lib/prism/node.rb.erb

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,10 +219,10 @@ module Prism
219219
def deconstruct_keys(keys)
220220
{ <%= (node.fields.map { |field| "#{field.name}: #{field.name}" } + ["location: location"]).join(", ") %> }
221221
end
222-
223222
<%- node.fields.each do |field| -%>
223+
224224
<%- if field.comment.nil? -%>
225-
# <%= "private " if field.is_a?(Prism::Template::FlagsField) %>attr_reader <%= field.name %>: <%= field.rbs_class %>
225+
# <%= "protected " if field.is_a?(Prism::Template::FlagsField) %>attr_reader <%= field.name %>: <%= field.rbs_class %>
226226
<%- else -%>
227227
<%- field.each_comment_line do |line| -%>
228228
#<%= line %>
@@ -248,9 +248,8 @@ module Prism
248248
end
249249
end
250250
<%- else -%>
251-
attr_reader :<%= field.name -%><%= "\n private :#{field.name}" if field.is_a?(Prism::Template::FlagsField) %>
251+
attr_reader :<%= field.name -%><%= "\n protected :#{field.name}" if field.is_a?(Prism::Template::FlagsField) %>
252252
<%- end -%>
253-
254253
<%- end -%>
255254
<%- node.fields.each do |field| -%>
256255
<%- case field -%>
@@ -349,6 +348,22 @@ module Prism
349348
def self.type
350349
:<%= node.human %>
351350
end
351+
352+
# Implements case-equality for the node. This is effectively == but without
353+
# comparing the value of locations. Locations are checked only for presence.
354+
def ===(other)
355+
other.is_a?(<%= node.name %>)<%= " &&" if node.fields.any? %>
356+
<%- node.fields.each_with_index do |field, index| -%>
357+
<%- if field.is_a?(Prism::Template::LocationField) || field.is_a?(Prism::Template::OptionalLocationField) -%>
358+
(<%= field.name %>.nil? == other.<%= field.name %>.nil?)<%= " &&" if index != node.fields.length - 1 %>
359+
<%- elsif field.is_a?(Prism::Template::NodeListField) || field.is_a?(Prism::Template::ConstantListField) -%>
360+
(<%= field.name %>.length == other.<%= field.name %>.length) &&
361+
<%= field.name %>.zip(other.<%= field.name %>).all? { |left, right| left === right }<%= " &&" if index != node.fields.length - 1 %>
362+
<%- else -%>
363+
(<%= field.name %> === other.<%= field.name %>)<%= " &&" if index != node.fields.length - 1 %>
364+
<%- end -%>
365+
<%- end -%>
366+
end
352367
end
353368
<%- end -%>
354369
<%- flags.each_with_index do |flag, flag_index| -%>

test/prism/ruby_api_test.rb

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,21 @@ def test_integer_base_flags
244244
assert_equal 16, base[parse_expression("0x1")]
245245
end
246246

247+
def test_node_equality
248+
assert_operator parse_expression("1"), :===, parse_expression("1")
249+
assert_operator Prism.parse("1").value, :===, Prism.parse("1").value
250+
251+
complex_source = "class Something; @var = something.else { _1 }; end"
252+
assert_operator parse_expression(complex_source), :===, parse_expression(complex_source)
253+
254+
refute_operator parse_expression("1"), :===, parse_expression("2")
255+
refute_operator parse_expression("1"), :===, parse_expression("0x1")
256+
257+
complex_source_1 = "class Something; @var = something.else { _1 }; end"
258+
complex_source_2 = "class Something; @var = something.else { _2 }; end"
259+
refute_operator parse_expression(complex_source_1), :===, parse_expression(complex_source_2)
260+
end
261+
247262
private
248263

249264
def parse_expression(source)

0 commit comments

Comments
 (0)