Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 73 additions & 19 deletions src/vaultwardentools/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"groups": {"sync": False, SYNC_ALL_GROUPS_ID: deepcopy(DEFAULT_CACHE)},
"users": deepcopy(DEFAULT_CACHE),
"organizations": deepcopy(DEFAULT_CACHE),
"organization_users": {},
"collections": {"sync": False, SYNC_ALL_ORGAS_ID: deepcopy(DEFAULT_CACHE)},
"ciphers": {
"sync": False,
Expand Down Expand Up @@ -471,6 +472,7 @@ def __init__(
vaultiersecretid=None,
unmarshall=False,
):
self.vaultiersecretid = None
if unmarshall:
jsond = unmarshall_value(jsond)
self._client = client
Expand Down Expand Up @@ -649,9 +651,9 @@ class Organization(BWFactory):
"""."""

def __init__(self, *a, **kw):
ret = super(Organization, self).__init__(*a, **kw)
self.name = None
super(Organization, self).__init__(*a, **kw)
self._complete = False
return ret


class Organizationuseruserdetails(BWFactory):
Expand All @@ -660,6 +662,10 @@ class Organizationuseruserdetails(BWFactory):

class Groupcollectiondetails(BWFactory):
"""."""
id: str
hidePasswords: bool
readOnly: bool
manage: bool

def to_payload(self):
return {
Expand All @@ -672,6 +678,7 @@ def to_payload(self):

class Groupdetails(BWFactory):
"""."""
name: str

def load_single(self, jsond=None):
super(Groupdetails, self).load(jsond)
Expand Down Expand Up @@ -775,6 +782,7 @@ class Collection(BWFactory):
"""."""

def __init__(self, *a, **kw):
self.organizationId = None
BWFactory.__init__(self, *a, **kw)
self.externalId = getattr(self, "externalId", None)
self._orga = None
Expand All @@ -794,6 +802,7 @@ class Collectiondetails(BWFactory):
"""."""

def __init__(self, *a, **kw):
self.organizationId = None
BWFactory.__init__(self, *a, **kw)
self.externalId = getattr(self, "externalId", None)
self._orga = None
Expand Down Expand Up @@ -1427,9 +1436,10 @@ def cache_group(self, r, cache_key=SYNC_ALL_GROUPS_ID, **kw):
return self._cache_objects(r, cache=self._cache["groups"], cache_key=cache_key, uniques=["id"], **kw)

def cache_collection(self, r, cache_key=SYNC_ALL_ORGAS_ID, **kw):
return self._cache_objects(
r, cache=self._cache["collections"], cache_key=cache_key, **kw
)
return self._cache_objects(r, cache=self._cache["collections"], cache_key=cache_key, **kw)

def cache_organization_users(self, r, org_id, **kw):
return self._cache_objects(r, org_id, cache=self._cache["organization_users"], uniques=["id", "email"], **kw)

def add_cipher(self, ret, obj, **kw):
return self._cache_object(obj, cache=ret)
Expand Down Expand Up @@ -2485,15 +2495,47 @@ def get_user(self, email=None, name=None, id=None, user=None, sync=None):
exc.criteria = criteria
raise exc

def get_users_from_organization(self, orga, include_groups=False, token=None):
def get_users_from_organization(self, orga, include_groups=False, sync=False, token=None):
token = self.get_token(token)
orga = self.get_organization(orga, token=token)
res = self.r(f"/api/organizations/{orga.id}/users" + ("?includeGroups=true" if include_groups else ""),
method="get")
users = {}
for user in res.json()["data"]:
users[user["id"]] = BWFactory.construct(user, client=self, unmarshall=True, )
return users
self.cache_organization_users(BWFactory.construct(user, client=self, unmarshall=True, ), orga.id)
return self.cache_organization_users([], orga.id)

def get_user_from_organization(self, orga, user, sync=False, token=None):
token = self.get_token(token)
if isinstance(user, Organizationuseruserdetails):
if not sync:
return user
else:
self.get_users_from_organization(orga, token, sync)
_id = self.item_or_id(user)
orga = self.get_organization(orga, token=token)
try:
cache = self.cache_organization_users([], orga.id)
except KeyError:
cache = self.get_users_from_organization(orga, token, sync)
try:
return cache["id"][_id]
except KeyError:
pass
try:
return cache["email"][_id]
except KeyError:
users = self.get_users_from_organization(orga, token, sync)
try:
return users["id"][_id]
except KeyError:
pass
try:
return users["email"][_id]
except KeyError:
pass
exc = OrganizationNotFound(f"No such user found {user}")
exc.criteria = [orga]
raise exc

def assert_bw_response(
self, response, expected_status_codes=None, expected_callback=None, *a, **kw
Expand Down Expand Up @@ -3547,7 +3589,7 @@ def confirm_invitation(self, orga, email, name=None, sync=None, token=None):
L.info(f"Confirmed user {email} / {user_id} in orga {orga.name} / {orga.id}")
return acl

def create_group(self, group, orga, users=[], collections=[], token=None):
def create_group(self, orga, group, users=[], collections=[], token=None):
v, i = self.version()
if i and (v < API_CHANGES["1.27.0"]):
raise WrongVersionOfServer(f"the server has version {v} and doesn't support groups")
Expand All @@ -3566,15 +3608,13 @@ def create_group(self, group, orga, users=[], collections=[], token=None):
d.load_single()
return d

def get_groups(self, orga, sync=None, cache=None, token=None):
def get_groups(self, orga, sync=False, cache=None, token=None):
v, i = self.version()
if i and (v < API_CHANGES["1.27.0"]):
raise WrongVersionOfServer(f"the server has version {v} and doesn't support groups")
token = self.get_token(token=token)
orga = self.get_organization(orga, token=token)
_CACHE = self._cache["groups"]
if sync is None:
sync = False
if cache is None:
cache = True
if cache is False or sync:
Expand Down Expand Up @@ -3680,14 +3720,17 @@ def get_users_from_group(self, group, orga=None, sync=None, token=None):
orga = self.get_organization(group.organizationId, token=token)
users_org = self.get_users_from_organization(orga, token=token)
for user_id in resp.json():
users[user_id] = deepcopy(users_org[user_id])
users[user_id] = deepcopy(users_org["id"][user_id])
return users

def edit_group(self,
group,
orga=None,
users=None,
collections=None,
readonly=False,
hidepasswords=False,
manage=False,
sync=None, token=None):
v, i = self.version()
if i and (v < API_CHANGES["1.27.0"]):
Expand All @@ -3702,12 +3745,23 @@ def edit_group(self,
payload = {
"users": get_ids(users),
"name": group.name,
"collections": []
"collections": [],
}
if len(collections) > 0 and isinstance(collections[0], Groupcollectiondetails):
payload["collections"] = [i.to_payload() for i in collections]
else:
payload["collections"] = collections
if collections:
if not isinstance(collections, list) and not isinstance(collections, str):
collections = [collections]
if isinstance(collections[0], Groupcollectiondetails):
payload["collections"] = [col.to_payload() for col in collections]
elif isinstance(collections, str):
payload["collections"] = collections
else:
orga = self.get_organization(group.organizationId, token=token, sync=sync)
dcollections = self.collections_to_payloads(
collections, orga=orga, token=token
)
payload["collections"] = self.compute_accesses(
dcollections, readonly=readonly, hidepasswords=hidepasswords, manage=manage
)["payloads"]
resp = self.r(f"/api/organizations/{group.organizationId}/groups/{_id}", json=payload, method="put",
token=token)
self.assert_bw_response(resp, expected_status_codes=[200, 500])
Expand Down
Loading