diff --git a/lib/omniauth/strategies/oauth2.rb b/lib/omniauth/strategies/oauth2.rb index 1588926..2bd96ff 100644 --- a/lib/omniauth/strategies/oauth2.rb +++ b/lib/omniauth/strategies/oauth2.rb @@ -71,8 +71,9 @@ def authorize_params # rubocop:disable Metrics/AbcSize, Metrics/MethodLength .merge(options_for("authorize")) .merge(pkce_authorize_params) - session["omniauth.pkce.verifier"] = options.pkce_verifier if options.pkce - session["omniauth.state"] = params[:state] + metadata = {} + metadata["pkce_verifier"] = options.pkce_verifier if options.pkce + store_state(params[:state], metadata) params end @@ -83,13 +84,14 @@ def token_params def callback_phase # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity error = request.params["error_reason"] || request.params["error"] - if !options.provider_ignores_state && (request.params["state"].to_s.empty? || !secure_compare(request.params["state"], session.delete("omniauth.state"))) + if !options.provider_ignores_state && (request.params["state"].to_s.empty? || !valid_state?(request.params["state"])) fail!(:csrf_detected, CallbackError.new(:csrf_detected, "CSRF detected")) elsif error fail!(error, CallbackError.new(request.params["error"], request.params["error_description"] || request.params["error_reason"], request.params["error_uri"])) else self.access_token = build_access_token self.access_token = access_token.refresh! if access_token.expired? + cleanup_expired_state super end rescue ::OAuth2::Error, CallbackError => e @@ -102,6 +104,63 @@ def callback_phase # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexi protected + def migrate_legacy_session + return unless session.key?("omniauth.state") + old_state = session.delete("omniauth.state") + old_verifier = session.delete("omniauth.pkce.verifier") + + return unless old_state + + session["omniauth.oauth2_states"] ||= {} + session["omniauth.oauth2_states"][old_state] = {"iat" => Time.now.to_i, "exp" => nil} + session["omniauth.oauth2_states"][old_state]["pkce_verifier"] = old_verifier if old_verifier + end + + def store_state(state, metadata = {}) + migrate_legacy_session + session["omniauth.oauth2_states"] ||= {} + session["omniauth.oauth2_states"][state] = metadata.merge("iat" => Time.now.to_i, "exp" => nil) + end + + def find_state(state) + session["omniauth.oauth2_states"] ||= {} + + session["omniauth.oauth2_states"].each_pair do |stored_state, metadata| + if secure_compare(state, stored_state) + return [stored_state, metadata] + end + end + + nil + end + + def valid_state?(state) + migrate_legacy_session + found = find_state(state) + + return false unless found + + stored_state, metadata = found + + return false if metadata["exp"] + + session["omniauth.oauth2_states"][stored_state]["exp"] = Time.now.to_i + + true + end + + def get_state_metadata(state) + found = find_state(state) + found ? found[1] : nil + end + + def cleanup_expired_state + session["omniauth.oauth2_states"] ||= {} + session["omniauth.oauth2_states"].delete_if do |_state, metadata| + metadata["exp"] + end + end + def pkce_authorize_params return {} unless options.pkce @@ -118,7 +177,11 @@ def pkce_authorize_params def pkce_token_params return {} unless options.pkce - {:code_verifier => session.delete("omniauth.pkce.verifier")} + state = request.params["state"] + metadata = get_state_metadata(state) + verifier = metadata && metadata["pkce_verifier"] + + {:code_verifier => verifier} end def build_access_token diff --git a/spec/omniauth/strategies/oauth2_spec.rb b/spec/omniauth/strategies/oauth2_spec.rb index a267773..14cdd27 100644 --- a/spec/omniauth/strategies/oauth2_spec.rb +++ b/spec/omniauth/strategies/oauth2_spec.rb @@ -58,20 +58,41 @@ def app it "includes random state in the authorize params" do instance = subject.new("abc", "def") expect(instance.authorize_params.keys).to eq(["state"]) - expect(instance.session["omniauth.state"]).not_to be_empty + state = instance.authorize_params["state"] + expect(instance.session["omniauth.oauth2_states"]).to have_key(state) end it "includes custom state in the authorize params" do instance = subject.new("abc", "def", :state => proc { "qux" }) expect(instance.authorize_params.keys).to eq(["state"]) - expect(instance.session["omniauth.state"]).to eq("qux") + expect(instance.session["omniauth.oauth2_states"]).to have_key("qux") + end + + it "supports multiple concurrent states" do + instance = subject.new("abc", "def") + state1 = instance.authorize_params["state"] + state2 = instance.authorize_params["state"] + state3 = instance.authorize_params["state"] + expect(instance.session["omniauth.oauth2_states"].keys).to match_array([state1, state2, state3]) + end + + it "migrates old single state to states hash" do + instance = subject.new("abc", "def") + instance.authorize_params + instance.session["omniauth.state"] = "old_state" + instance.session.delete("omniauth.oauth2_states") + new_state = instance.authorize_params["state"] + expect(instance.session["omniauth.oauth2_states"].keys).to match_array(["old_state", new_state]) + expect(instance.session["omniauth.state"]).to be_nil end it "includes PKCE parameters if enabled" do instance = subject.new("abc", "def", :pkce => true) - expect(instance.authorize_params[:code_challenge]).to be_a(String) - expect(instance.authorize_params[:code_challenge_method]).to eq("S256") - expect(instance.session["omniauth.pkce.verifier"]).to be_a(String) + params = instance.authorize_params + expect(params[:code_challenge]).to be_a(String) + expect(params[:code_challenge_method]).to eq("S256") + state = params["state"] + expect(instance.session["omniauth.oauth2_states"][state]["pkce_verifier"]).to be_a(String) end end @@ -90,8 +111,9 @@ def app it "includes the PKCE code_verifier if enabled" do instance = subject.new("abc", "def", :pkce => true) - # setup session - instance.authorize_params + params = instance.authorize_params + state = params["state"] + allow(instance).to receive(:request).and_return(double("Request", :params => {"state" => state})) expect(instance.token_params[:code_verifier]).to be_a(String) end end @@ -106,21 +128,50 @@ def app allow(instance).to receive(:request) do double("Request", :params => params) end + end + + context "with new states hash format" do + it "calls fail with the error received" do + session = {"omniauth.oauth2_states" => {state => {"iat" => Time.now.to_i, "exp" => nil}}} + allow(instance).to receive(:session).and_return(session) + expect(instance).to receive(:fail!).with("user_denied", anything) + instance.callback_phase + end - allow(instance).to receive(:session) do - double("Session", :delete => state) + it "marks the validated state as expired" do + session = {"omniauth.oauth2_states" => { + state => {"iat" => Time.now.to_i, "exp" => nil}, + "other_state" => {"iat" => Time.now.to_i, "exp" => nil} + }} + allow(instance).to receive(:session).and_return(session) + expect(instance).to receive(:fail!).with("user_denied", anything) + instance.callback_phase + expect(session["omniauth.oauth2_states"][state]["exp"]).not_to be_nil + expect(session["omniauth.oauth2_states"]["other_state"]["exp"]).to be_nil end end - it "calls fail with the error received" do - expect(instance).to receive(:fail!).with("user_denied", anything) + context "with legacy single state format" do + it "calls fail with the error received" do + session = {"omniauth.state" => state} + allow(instance).to receive(:session).and_return(session) + expect(instance).to receive(:fail!).with("user_denied", anything) + instance.callback_phase + end - instance.callback_phase + it "removes the old state key after validation" do + session = {"omniauth.state" => state} + allow(instance).to receive(:session).and_return(session) + expect(instance).to receive(:fail!).with("user_denied", anything) + instance.callback_phase + expect(session).not_to have_key("omniauth.state") + end end it "calls fail with the error received if state is missing and CSRF verification is disabled" do params["state"] = nil instance.options.provider_ignores_state = true + allow(instance).to receive(:session).and_return({}) expect(instance).to receive(:fail!).with("user_denied", anything) @@ -129,6 +180,7 @@ def app it "calls fail with a CSRF error if the state is missing" do params["state"] = nil + allow(instance).to receive(:session).and_return({}) expect(instance).to receive(:fail!).with(:csrf_detected, anything) instance.callback_phase @@ -136,17 +188,67 @@ def app it "calls fail with a CSRF error if the state is invalid" do params["state"] = "invalid" + allow(instance).to receive(:session).and_return({"omniauth.oauth2_states" => {state => {"iat" => Time.now.to_i, "exp" => nil}}}) expect(instance).to receive(:fail!).with(:csrf_detected, anything) instance.callback_phase end + it "validates concurrent states correctly" do + state1 = "state1" + state2 = "state2" + state3 = "state3" + params["state"] = state2 + session = {"omniauth.oauth2_states" => { + state1 => {"iat" => Time.now.to_i, "exp" => nil}, + state2 => {"iat" => Time.now.to_i, "exp" => nil}, + state3 => {"iat" => Time.now.to_i, "exp" => nil} + }} + allow(instance).to receive(:session).and_return(session) + + expect(instance).to receive(:fail!).with("user_denied", anything) + instance.callback_phase + expect(session["omniauth.oauth2_states"][state2]["exp"]).not_to be_nil + expect(session["omniauth.oauth2_states"][state1]["exp"]).to be_nil + expect(session["omniauth.oauth2_states"][state3]["exp"]).to be_nil + end + + it "prevents replay attacks by rejecting already-used states" do + params["state"] = state + session = {"omniauth.oauth2_states" => { + state => {"iat" => Time.now.to_i, "exp" => Time.now.to_i - 60} + }} + allow(instance).to receive(:session).and_return(session) + + expect(instance).to receive(:fail!).with(:csrf_detected, anything) + instance.callback_phase + end + + it "cleans up all states with exp set" do + state1 = "state1" + state2 = "state2" + state3 = "state3" + session = {"omniauth.oauth2_states" => { + state1 => {"iat" => Time.now.to_i, "exp" => Time.now.to_i}, + state2 => {"iat" => Time.now.to_i, "exp" => Time.now.to_i - 60}, + state3 => {"iat" => Time.now.to_i, "exp" => nil} + }} + allow(instance).to receive(:session).and_return(session) + + instance.send(:cleanup_expired_state) + + expect(session["omniauth.oauth2_states"]).not_to have_key(state1) + expect(session["omniauth.oauth2_states"]).not_to have_key(state2) + expect(session["omniauth.oauth2_states"]).to have_key(state3) + end + describe 'exception handlings' do let(:params) do {"code" => "code", "state" => state} end before do + allow(instance).to receive(:session).and_return({"omniauth.oauth2_states" => {state => {"iat" => Time.now.to_i, "exp" => nil}}}) allow_any_instance_of(OmniAuth::Strategies::OAuth2).to receive(:build_access_token).and_raise(exception) end