@@ -219,10 +219,10 @@ module Prism
219219 def deconstruct_keys(keys)
220220 { <%= ( node . fields . map { |field | "#{ field . name } : #{ field . name } " } + [ "location: location" ] ) . join ( ", " ) %> }
221221 end
222-
223222 <%- node . fields . each do |field | -%>
223+
224224 <%- if field . comment . nil? -%>
225- # <%= "private " if field . is_a? ( Prism ::Template ::FlagsField ) %> attr_reader <%= field . name %> : <%= field . rbs_class %>
225+ # <%= "protected " if field . is_a? ( Prism ::Template ::FlagsField ) %> attr_reader <%= field . name %> : <%= field . rbs_class %>
226226 <%- else -%>
227227 <%- field . each_comment_line do |line | -%>
228228 #<%= line %>
@@ -248,9 +248,8 @@ module Prism
248248 end
249249 end
250250 <%- else -%>
251- attr_reader :<%= field . name -%> <%= "\n private :#{ field . name } " if field . is_a? ( Prism ::Template ::FlagsField ) %>
251+ attr_reader :<%= field . name -%> <%= "\n protected :#{ field . name } " if field . is_a? ( Prism ::Template ::FlagsField ) %>
252252 <%- end -%>
253-
254253 <%- end -%>
255254 <%- node . fields . each do |field | -%>
256255 <%- case field -%>
@@ -349,6 +348,22 @@ module Prism
349348 def self.type
350349 :<%= node . human %>
351350 end
351+
352+ # Implements case-equality for the node. This is effectively == but without
353+ # comparing the value of locations. Locations are checked only for presence.
354+ def ===(other)
355+ other.is_a?(<%= node . name %> )<%= " &&" if node . fields . any? %>
356+ <%- node . fields . each_with_index do |field , index | -%>
357+ <%- if field . is_a? ( Prism ::Template ::LocationField ) || field . is_a? ( Prism ::Template ::OptionalLocationField ) -%>
358+ (<%= field . name %> .nil? == other.<%= field . name %> .nil?)<%= " &&" if index != node . fields . length - 1 %>
359+ <%- elsif field . is_a? ( Prism ::Template ::NodeListField ) || field . is_a? ( Prism ::Template ::ConstantListField ) -%>
360+ (<%= field . name %> .length == other.<%= field . name %> .length) &&
361+ <%= field . name %> .zip(other.<%= field . name %> ).all? { |left, right| left === right }<%= " &&" if index != node . fields . length - 1 %>
362+ <%- else -%>
363+ (<%= field . name %> === other.<%= field . name %> )<%= " &&" if index != node . fields . length - 1 %>
364+ <%- end -%>
365+ <%- end -%>
366+ end
352367 end
353368 <%- end -%>
354369 <%- flags . each_with_index do |flag , flag_index | -%>
0 commit comments