如何在django-rest-framework中覆盖默认的create方法

时间:2014-09-04 13:01:39

标签: python django django-rest-framework

我有一个模型，它的管理器（Manager）中定义了一个不同于默认实现的create方法。如何覆盖默认行为，使ListCreateAPIView中的post方法调用我编写的方法而不是默认方法？方法如下：

class WeddingInviteManager(models.Manager):
    """Manager whose ``create`` also sends a ``wedding_invite`` notification."""

    def create(self, to_user, from_user, wedding):
        """Create an invitation row and notify ``to_user`` about it.

        :param to_user: value for the model's ``to_user`` field; also the
            recipient of the notification.
        :param from_user: value for the model's ``from_user`` field.
        :param wedding: value for the model's ``wedding`` field.
        :returns: the newly created instance.
        """
        # Delegate to the base Manager.create. Calling ``self.create`` here
        # would re-enter this very method (the keyword names match its own
        # parameters) and recurse until the stack overflows.
        wedding_invitation = super(WeddingInviteManager, self).create(
            from_user=from_user, to_user=to_user, wedding=wedding)
        # A Manager has no ``to_user`` attribute; notify the argument instead.
        notification.send([to_user], 'wedding_invite',
                          {'invitation': wedding_invitation})
        return wedding_invitation

1 个答案:

答案 0（得分：4）

我认为你这样做的真正原因其实是通知系统。

我建议这样做：

class MyModel(models.Model):
    ...

    def save(self, silent=False, *args, **kwargs):
        """Save the instance and notify ``self.to_user`` on its first save.

        :param silent: when true, no notification is sent.
        All remaining arguments are forwarded unchanged to ``Model.save``.

        NOTE: ``silent`` is declared before ``*args``, so a first positional
        argument (e.g. Django's ``force_insert``) would bind to ``silent``.
        Pass ``silent`` and Django's own options by keyword.
        """
        # Decide whether this is a new instance *before* saving. Once the row
        # is saved, an auto-assigned primary key makes ``self.pk`` truthy.
        is_new = not self.pk
        result = super(MyModel, self).save(*args, **kwargs)
        # Notify only after the save succeeded. Then the notification never
        # refers to an invitation whose save raised, and the instance carries
        # its primary key.
        if is_new and not silent:
            notification.send([self.to_user], 'wedding_invite', {'invitation': self})
        return result

但如果你确实必须这样做，理论上可以按下面的方式实现（代码未经测试）：

from rest_framework import serializers, viewsets

class MyModelSerializer(serializers.ModelSerializer):
    """ModelSerializer that creates new rows through ``MyModel.objects.create``."""

    def save_object(self, obj, **kwargs):
        """
        Save the deserialized object.

        Objects that already have a primary key are saved with
        ``obj.save(**kwargs)``. New objects are created with
        ``MyModel.objects.create(to_user, from_user, wedding)`` so the
        manager's custom ``create`` runs; ``kwargs`` are not passed on in that
        case. Afterwards, any pending nested, many-to-many and
        reverse-relation data stashed on ``obj`` is saved against it.

        NOTE(review): ``models`` and ``RelationsList`` are not imported in
        this snippet; they must be in scope where this class is defined.
        """
        if getattr(obj, '_nested_forward_relations', None):
            # Nested relationships need to be saved before we can save the
            # parent instance.
            for field_name, sub_object in obj._nested_forward_relations.items():
                if sub_object:
                    self.save_object(sub_object)
                setattr(obj, field_name, sub_object)

        #####  EDITED CODE #####
        if obj.pk:
            obj.save(**kwargs)
        else:
            # Create the row through the custom manager method. Do NOT rebind
            # ``obj`` to the returned instance: the pending ``_m2m_data`` and
            # ``_related_data`` handled below live on ``obj``, and a rebound
            # ``obj`` would silently drop them. Copy the new primary key back
            # so ``obj`` itself refers to the saved row.
            created = MyModel.objects.create(obj.to_user, obj.from_user, obj.wedding)
            obj.pk = created.pk
        ##### /EDITED CODE #####

        if getattr(obj, '_m2m_data', None):
            for accessor_name, object_list in obj._m2m_data.items():
                setattr(obj, accessor_name, object_list)
            del obj._m2m_data

        if getattr(obj, '_related_data', None):
            related_fields = dict([
                (field.get_accessor_name(), field)
                for field, _model
                in obj._meta.get_all_related_objects_with_model()
            ])
            for accessor_name, related in obj._related_data.items():
                if isinstance(related, RelationsList):
                    # Nested reverse fk relationship
                    for related_item in related:
                        fk_field = related_fields[accessor_name].field.name
                        setattr(related_item, fk_field, obj)
                        self.save_object(related_item)

                    # Delete any removed objects
                    if related._deleted:
                        for item in related._deleted:
                            self.delete_object(item)

                elif isinstance(related, models.Model):
                    # Nested reverse one-one relationship
                    fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
                    setattr(related, fk_field, obj)
                    self.save_object(related)
                else:
                    # Reverse FK or reverse one-one
                    setattr(obj, accessor_name, related)
            del obj._related_data


class MyModelViewSet(viewsets.ModelViewSet):
    """Viewset over every ``MyModel`` row, serialized with ``MyModelSerializer``."""

    queryset = MyModel.objects.all()
    serializer_class = MyModelSerializer