diff --git a/kolibri/core/content/test/test_public_api.py b/kolibri/core/content/test/test_public_api.py index b625f4fdd45..62d5d7b828d 100644 --- a/kolibri/core/content/test/test_public_api.py +++ b/kolibri/core/content/test/test_public_api.py @@ -4,6 +4,7 @@ from le_utils.constants import content_kinds from rest_framework.test import APITestCase +from kolibri.core.content import base_models from kolibri.core.content import models as content from kolibri.core.content.constants.schema_versions import CONTENT_SCHEMA_VERSION from kolibri.core.content.test.test_channel_upgrade import ChannelBuilder @@ -40,11 +41,20 @@ def _assert_data(self, Model, queryset): response = self.client.get( reverse("kolibri:core:importmetadata-detail", kwargs={"pk": self.node.id}) ) + fields = Model._meta.fields + BaseModel = getattr(base_models, Model.__name__, Model) + field_names = {field.column for field in BaseModel._meta.fields} + if hasattr(BaseModel, "_mptt_meta"): + field_names.add(BaseModel._mptt_meta.parent_attr) + field_names.add(BaseModel._mptt_meta.tree_id_attr) + field_names.add(BaseModel._mptt_meta.left_attr) + field_names.add(BaseModel._mptt_meta.right_attr) + field_names.add(BaseModel._mptt_meta.level_attr) for response_data, obj in zip(response.data[Model._meta.db_table], queryset): # Ensure that we are not returning any empty objects self.assertNotEqual(response_data, {}) - for field in Model._meta.fields: - if field.column in response_data: + for field in fields: + if field.column in field_names: value = response_data[field.column] if hasattr(field, "from_db_value"): value = field.from_db_value(value, None, connection, None)