From ddcf216453981bf8abac7daea1908610c022dd2f Mon Sep 17 00:00:00 2001 From: Nick Stanish <1869157+nickstanish@users.noreply.github.com> Date: Fri, 30 Jan 2026 14:25:45 -0600 Subject: [PATCH 1/2] store authorize state in list --- lib/omniauth/strategies/oauth2.rb | 37 ++++++++++- spec/omniauth/strategies/oauth2_spec.rb | 86 +++++++++++++++++++++++-- 2 files changed, 114 insertions(+), 9 deletions(-) diff --git a/lib/omniauth/strategies/oauth2.rb b/lib/omniauth/strategies/oauth2.rb index 1588926..6d84684 100644 --- a/lib/omniauth/strategies/oauth2.rb +++ b/lib/omniauth/strategies/oauth2.rb @@ -72,7 +72,7 @@ def authorize_params # rubocop:disable Metrics/AbcSize, Metrics/MethodLength .merge(pkce_authorize_params) session["omniauth.pkce.verifier"] = options.pkce_verifier if options.pkce - session["omniauth.state"] = params[:state] + store_state(params[:state]) params end @@ -83,7 +83,7 @@ def token_params def callback_phase # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity error = request.params["error_reason"] || request.params["error"] - if !options.provider_ignores_state && (request.params["state"].to_s.empty? || !secure_compare(request.params["state"], session.delete("omniauth.state"))) + if !options.provider_ignores_state && (request.params["state"].to_s.empty? || !valid_state?(request.params["state"])) fail!(:csrf_detected, CallbackError.new(:csrf_detected, "CSRF detected")) elsif error fail!(error, CallbackError.new(request.params["error"], request.params["error_description"] || request.params["error_reason"], request.params["error_uri"])) @@ -102,6 +102,39 @@ def callback_phase # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexi protected + def copy_legacy_session_state + return unless session.key?("omniauth.state") + old_state = session.delete('omniauth.state') + + return unless old_state + + session["omniauth.states"] = Array(session["omniauth.states"]) + Array(old_state) + end + + def store_state(state) + copy_legacy_session_state + session["omniauth.states"] = Array(session["omniauth.states"]) + session["omniauth.states"] << state + end + + def valid_state?(state) + copy_legacy_session_state + + Array(session["omniauth.states"]).each_with_index do |stored_state, index| + if secure_compare(state, stored_state) + delete_state(stored_state) + return true + end + end + + false + end + + def delete_state(state) + session["omniauth.states"] = Array(session["omniauth.states"]) + session["omniauth.states"] = (session["omniauth.states"] - [state]).uniq + end + def pkce_authorize_params return {} unless options.pkce diff --git a/spec/omniauth/strategies/oauth2_spec.rb b/spec/omniauth/strategies/oauth2_spec.rb index a267773..6062acc 100644 --- a/spec/omniauth/strategies/oauth2_spec.rb +++ b/spec/omniauth/strategies/oauth2_spec.rb @@ -58,13 +58,31 @@ def app it "includes random state in the authorize params" do instance = subject.new("abc", "def") expect(instance.authorize_params.keys).to eq(["state"]) - expect(instance.session["omniauth.state"]).not_to be_empty + expect(instance.session["omniauth.states"]).to include(instance.authorize_params["state"]) end it "includes custom state in the authorize params" do instance = subject.new("abc", "def", :state => proc { "qux" }) expect(instance.authorize_params.keys).to eq(["state"]) - expect(instance.session["omniauth.state"]).to eq("qux") + expect(instance.session["omniauth.states"]).to include("qux") + end + + it "supports multiple concurrent states" do + instance = subject.new("abc", "def") + state1 = instance.authorize_params["state"] + state2 = instance.authorize_params["state"] + state3 = instance.authorize_params["state"] + expect(instance.session["omniauth.states"]).to eq([state1, state2, state3]) + end + + it "migrates old single state to states array" do + instance = subject.new("abc", "def") + instance.authorize_params + instance.session["omniauth.state"] = "old_state" + instance.session.delete("omniauth.states") + new_state = instance.authorize_params["state"] + expect(instance.session["omniauth.states"]).to eq(["old_state", new_state]) + expect(instance.session["omniauth.state"]).to be_nil end it "includes PKCE parameters if enabled" do @@ -106,21 +124,46 @@ def app allow(instance).to receive(:request) do double("Request", :params => params) end + end + + context "with new states array format" do + it "calls fail with the error received" do + session = {"omniauth.states" => [state]} + allow(instance).to receive(:session).and_return(session) + expect(instance).to receive(:fail!).with("user_denied", anything) + instance.callback_phase + end - allow(instance).to receive(:session) do - double("Session", :delete => state) + it "removes the validated state from the states array" do + session = {"omniauth.states" => [state, "other_state"]} + allow(instance).to receive(:session).and_return(session) + expect(instance).to receive(:fail!).with("user_denied", anything) + instance.callback_phase + expect(session["omniauth.states"]).to eq(["other_state"]) end end - it "calls fail with the error received" do - expect(instance).to receive(:fail!).with("user_denied", anything) + context "with legacy single state format" do + it "calls fail with the error received" do + session = {"omniauth.state" => state} + allow(instance).to receive(:session).and_return(session) + expect(instance).to receive(:fail!).with("user_denied", anything) + instance.callback_phase + end - instance.callback_phase + it "removes the old state key after validation" do + session = {"omniauth.state" => state} + allow(instance).to receive(:session).and_return(session) + expect(instance).to receive(:fail!).with("user_denied", anything) + instance.callback_phase + expect(session).not_to have_key("omniauth.state") + end end it "calls fail with the error received if state is missing and CSRF verification is disabled" do params["state"] = nil instance.options.provider_ignores_state = true + allow(instance).to receive(:session).and_return({}) expect(instance).to receive(:fail!).with("user_denied", anything) @@ -129,6 +172,7 @@ def app it "calls fail with a CSRF error if the state is missing" do params["state"] = nil + allow(instance).to receive(:session).and_return({}) expect(instance).to receive(:fail!).with(:csrf_detected, anything) instance.callback_phase @@ -136,17 +180,45 @@ def app it "calls fail with a CSRF error if the state is invalid" do params["state"] = "invalid" + allow(instance).to receive(:session).and_return({"omniauth.states" => [state]}) expect(instance).to receive(:fail!).with(:csrf_detected, anything) instance.callback_phase end + it "validates concurrent states correctly" do + state1 = "state1" + state2 = "state2" + state3 = "state3" + params["state"] = state2 + session = {"omniauth.states" => [state1, state2, state3]} + allow(instance).to receive(:session).and_return(session) + + expect(instance).to receive(:fail!).with("user_denied", anything) + instance.callback_phase + expect(session["omniauth.states"]).to eq([state1, state3]) + end + + it "removes duplicate states" do + state1 = "state1" + state2 = "state2" + state3 = "state3" + params["state"] = state2 + session = {"omniauth.states" => [state1, state1, state2, state2, state3, state3]} + allow(instance).to receive(:session).and_return(session) + + expect(instance).to receive(:fail!).with("user_denied", anything) + instance.callback_phase + expect(session["omniauth.states"]).to eq([state1, state3]) + end + describe 'exception handlings' do let(:params) do {"code" => "code", "state" => state} end before do + allow(instance).to receive(:session).and_return({"omniauth.states" => [state]}) allow_any_instance_of(OmniAuth::Strategies::OAuth2).to receive(:build_access_token).and_raise(exception) end From 76393fb2b0b8807c9ff8e1321be1ff3c21d8e7af Mon Sep 17 00:00:00 2001 From: Nick Stanish <1869157+nickstanish@users.noreply.github.com> Date: Fri, 30 Jan 2026 15:16:46 -0600 Subject: [PATCH 2/2] use hash session instead of array to store pkce and exp values --- lib/omniauth/strategies/oauth2.rb | 68 ++++++++++++++------ spec/omniauth/strategies/oauth2_spec.rb | 82 +++++++++++++++++-------- 2 files changed, 105 insertions(+), 45 deletions(-) diff --git a/lib/omniauth/strategies/oauth2.rb b/lib/omniauth/strategies/oauth2.rb index 6d84684..2bd96ff 100644 --- a/lib/omniauth/strategies/oauth2.rb +++ b/lib/omniauth/strategies/oauth2.rb @@ -71,8 +71,9 @@ def authorize_params # rubocop:disable Metrics/AbcSize, Metrics/MethodLength .merge(options_for("authorize")) .merge(pkce_authorize_params) - session["omniauth.pkce.verifier"] = options.pkce_verifier if options.pkce - store_state(params[:state]) + metadata = {} + metadata["pkce_verifier"] = options.pkce_verifier if options.pkce + store_state(params[:state], metadata) params end @@ -90,6 +91,7 @@ def callback_phase # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexi else self.access_token = build_access_token self.access_token = access_token.refresh! if access_token.expired? + cleanup_expired_state super end rescue ::OAuth2::Error, CallbackError => e @@ -102,37 +104,61 @@ def callback_phase # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexi protected - def copy_legacy_session_state + def migrate_legacy_session return unless session.key?("omniauth.state") - old_state = session.delete('omniauth.state') + old_state = session.delete("omniauth.state") + old_verifier = session.delete("omniauth.pkce.verifier") return unless old_state - session["omniauth.states"] = Array(session["omniauth.states"]) + Array(old_state) + session["omniauth.oauth2_states"] ||= {} + session["omniauth.oauth2_states"][old_state] = {"iat" => Time.now.to_i, "exp" => nil} + session["omniauth.oauth2_states"][old_state]["pkce_verifier"] = old_verifier if old_verifier end - def store_state(state) - copy_legacy_session_state - session["omniauth.states"] = Array(session["omniauth.states"]) - session["omniauth.states"] << state + def store_state(state, metadata = {}) + migrate_legacy_session + session["omniauth.oauth2_states"] ||= {} + session["omniauth.oauth2_states"][state] = metadata.merge("iat" => Time.now.to_i, "exp" => nil) end - def valid_state?(state) - copy_legacy_session_state + def find_state(state) + session["omniauth.oauth2_states"] ||= {} - Array(session["omniauth.states"]).each_with_index do |stored_state, index| + session["omniauth.oauth2_states"].each_pair do |stored_state, metadata| if secure_compare(state, stored_state) - delete_state(stored_state) - return true + return [stored_state, metadata] end end - false + nil end - def delete_state(state) - session["omniauth.states"] = Array(session["omniauth.states"]) - session["omniauth.states"] = (session["omniauth.states"] - [state]).uniq + def valid_state?(state) + migrate_legacy_session + found = find_state(state) + + return false unless found + + stored_state, metadata = found + + return false if metadata["exp"] + + session["omniauth.oauth2_states"][stored_state]["exp"] = Time.now.to_i + + true + end + + def get_state_metadata(state) + found = find_state(state) + found ? found[1] : nil + end + + def cleanup_expired_state + session["omniauth.oauth2_states"] ||= {} + session["omniauth.oauth2_states"].delete_if do |_state, metadata| + metadata["exp"] + end end def pkce_authorize_params @@ -151,7 +177,11 @@ def pkce_authorize_params def pkce_token_params return {} unless options.pkce - {:code_verifier => session.delete("omniauth.pkce.verifier")} + state = request.params["state"] + metadata = get_state_metadata(state) + verifier = metadata && metadata["pkce_verifier"] + + {:code_verifier => verifier} end def build_access_token diff --git a/spec/omniauth/strategies/oauth2_spec.rb b/spec/omniauth/strategies/oauth2_spec.rb index 6062acc..14cdd27 100644 --- a/spec/omniauth/strategies/oauth2_spec.rb +++ b/spec/omniauth/strategies/oauth2_spec.rb @@ -58,13 +58,14 @@ def app it "includes random state in the authorize params" do instance = subject.new("abc", "def") expect(instance.authorize_params.keys).to eq(["state"]) - expect(instance.session["omniauth.states"]).to include(instance.authorize_params["state"]) + state = instance.authorize_params["state"] + expect(instance.session["omniauth.oauth2_states"]).to have_key(state) end it "includes custom state in the authorize params" do instance = subject.new("abc", "def", :state => proc { "qux" }) expect(instance.authorize_params.keys).to eq(["state"]) - expect(instance.session["omniauth.states"]).to include("qux") + expect(instance.session["omniauth.oauth2_states"]).to have_key("qux") end it "supports multiple concurrent states" do @@ -72,24 +73,26 @@ def app state1 = instance.authorize_params["state"] state2 = instance.authorize_params["state"] state3 = instance.authorize_params["state"] - expect(instance.session["omniauth.states"]).to eq([state1, state2, state3]) + expect(instance.session["omniauth.oauth2_states"].keys).to match_array([state1, state2, state3]) end - it "migrates old single state to states array" do + it "migrates old single state to states hash" do instance = subject.new("abc", "def") instance.authorize_params instance.session["omniauth.state"] = "old_state" - instance.session.delete("omniauth.states") + instance.session.delete("omniauth.oauth2_states") new_state = instance.authorize_params["state"] - expect(instance.session["omniauth.states"]).to eq(["old_state", new_state]) + expect(instance.session["omniauth.oauth2_states"].keys).to match_array(["old_state", new_state]) expect(instance.session["omniauth.state"]).to be_nil end it "includes PKCE parameters if enabled" do instance = subject.new("abc", "def", :pkce => true) - expect(instance.authorize_params[:code_challenge]).to be_a(String) - expect(instance.authorize_params[:code_challenge_method]).to eq("S256") - expect(instance.session["omniauth.pkce.verifier"]).to be_a(String) + params = instance.authorize_params + expect(params[:code_challenge]).to be_a(String) + expect(params[:code_challenge_method]).to eq("S256") + state = params["state"] + expect(instance.session["omniauth.oauth2_states"][state]["pkce_verifier"]).to be_a(String) end end @@ -108,8 +111,9 @@ def app it "includes the PKCE code_verifier if enabled" do instance = subject.new("abc", "def", :pkce => true) - # setup session - instance.authorize_params + params = instance.authorize_params + state = params["state"] + allow(instance).to receive(:request).and_return(double("Request", :params => {"state" => state})) expect(instance.token_params[:code_verifier]).to be_a(String) end end @@ -126,20 +130,24 @@ def app end end - context "with new states array format" do + context "with new states hash format" do it "calls fail with the error received" do - session = {"omniauth.states" => [state]} + session = {"omniauth.oauth2_states" => {state => {"iat" => Time.now.to_i, "exp" => nil}}} allow(instance).to receive(:session).and_return(session) expect(instance).to receive(:fail!).with("user_denied", anything) instance.callback_phase end - it "removes the validated state from the states array" do - session = {"omniauth.states" => [state, "other_state"]} + it "marks the validated state as expired" do + session = {"omniauth.oauth2_states" => { + state => {"iat" => Time.now.to_i, "exp" => nil}, + "other_state" => {"iat" => Time.now.to_i, "exp" => nil} + }} allow(instance).to receive(:session).and_return(session) expect(instance).to receive(:fail!).with("user_denied", anything) instance.callback_phase - expect(session["omniauth.states"]).to eq(["other_state"]) + expect(session["omniauth.oauth2_states"][state]["exp"]).not_to be_nil + expect(session["omniauth.oauth2_states"]["other_state"]["exp"]).to be_nil end end @@ -180,7 +188,7 @@ def app it "calls fail with a CSRF error if the state is invalid" do params["state"] = "invalid" - allow(instance).to receive(:session).and_return({"omniauth.states" => [state]}) + allow(instance).to receive(:session).and_return({"omniauth.oauth2_states" => {state => {"iat" => Time.now.to_i, "exp" => nil}}}) expect(instance).to receive(:fail!).with(:csrf_detected, anything) instance.callback_phase @@ -191,25 +199,47 @@ def app state2 = "state2" state3 = "state3" params["state"] = state2 - session = {"omniauth.states" => [state1, state2, state3]} + session = {"omniauth.oauth2_states" => { + state1 => {"iat" => Time.now.to_i, "exp" => nil}, + state2 => {"iat" => Time.now.to_i, "exp" => nil}, + state3 => {"iat" => Time.now.to_i, "exp" => nil} + }} allow(instance).to receive(:session).and_return(session) expect(instance).to receive(:fail!).with("user_denied", anything) instance.callback_phase - expect(session["omniauth.states"]).to eq([state1, state3]) + expect(session["omniauth.oauth2_states"][state2]["exp"]).not_to be_nil + expect(session["omniauth.oauth2_states"][state1]["exp"]).to be_nil + expect(session["omniauth.oauth2_states"][state3]["exp"]).to be_nil end - it "removes duplicate states" do + it "prevents replay attacks by rejecting already-used states" do + params["state"] = state + session = {"omniauth.oauth2_states" => { + state => {"iat" => Time.now.to_i, "exp" => Time.now.to_i - 60} + }} + allow(instance).to receive(:session).and_return(session) + + expect(instance).to receive(:fail!).with(:csrf_detected, anything) + instance.callback_phase + end + + it "cleans up all states with exp set" do state1 = "state1" state2 = "state2" state3 = "state3" - params["state"] = state2 - session = {"omniauth.states" => [state1, state1, state2, state2, state3, state3]} + session = {"omniauth.oauth2_states" => { + state1 => {"iat" => Time.now.to_i, "exp" => Time.now.to_i}, + state2 => {"iat" => Time.now.to_i, "exp" => Time.now.to_i - 60}, + state3 => {"iat" => Time.now.to_i, "exp" => nil} + }} allow(instance).to receive(:session).and_return(session) - expect(instance).to receive(:fail!).with("user_denied", anything) - instance.callback_phase - expect(session["omniauth.states"]).to eq([state1, state3]) + instance.send(:cleanup_expired_state) + + expect(session["omniauth.oauth2_states"]).not_to have_key(state1) + expect(session["omniauth.oauth2_states"]).not_to have_key(state2) + expect(session["omniauth.oauth2_states"]).to have_key(state3) end describe 'exception handlings' do @@ -218,7 +248,7 @@ def app end before do - allow(instance).to receive(:session).and_return({"omniauth.states" => [state]}) + allow(instance).to receive(:session).and_return({"omniauth.oauth2_states" => {state => {"iat" => Time.now.to_i, "exp" => nil}}}) allow_any_instance_of(OmniAuth::Strategies::OAuth2).to receive(:build_access_token).and_raise(exception) end