将案例类字段与Scala中另一个案例类的子字段进行比较

时间:2019-05-12 12:52:27

标签: scala functional-programming

我有以下3个案例类:

case class Profile(name: String,
                   age: Int,
                   bankInfoData: BankInfoData,
                   userUpdatedFields: Option[UserUpdatedFields])

case class BankInfoData(accountNumber: Int,
                        bankAddress: String,
                        bankNumber: Int,
                        contactPerson: String,
                        phoneNumber: Int,
                        accountType: AccountType)

case class UserUpdatedFields(contactPerson: String,
                             phoneNumber: Int,
                             accountType: AccountType)

这只是枚举,但我还是添加了:

sealed trait AccountType extends EnumEntry

object AccountType extends Enum[AccountType] {
  val values: IndexedSeq[AccountType] = findValues

  case object Personal extends AccountType

  case object Business extends AccountType

}

我的任务是-我需要编写一个funcc Profile,并将UserUpdatedFields(所有字段)与BankInfoData中的某些字段进行比较...此功能是查找要更新的字段。

所以我写了这个func:

def findDiff(profile: Profile): Seq[String] = {
  var listOfFieldsThatChanged: List[String] = List.empty
  if (profile.bankInfoData.contactPerson != profile.userUpdatedFields.get.contactPerson){
    listOfFieldsThatChanged = listOfFieldsThatChanged :+ "contactPerson"
  }
  if (profile.bankInfoData.phoneNumber != profile.userUpdatedFields.get.phoneNumber) {
    listOfFieldsThatChanged = listOfFieldsThatChanged :+ "phoneNumber"
  }
  if (profile.bankInfoData.accountType != profile.userUpdatedFields.get.accountType) {
    listOfFieldsThatChanged = listOfFieldsThatChanged :+ "accountType"
  }
  listOfFieldsThatChanged
}

val profile =
  Profile(
    "nir",
    34,
    BankInfoData(1, "somewhere", 2, "john", 123, AccountType.Personal),
    Some(UserUpdatedFields("lee", 321, AccountType.Personal))
  )

findDiff(profile)

它可以工作,但是想要更干净的..任何建议吗?

3 个答案:

答案 0 :(得分:3)

一个简单的改进就是引入特质

trait Fields {
  val contactPerson: String
  val phoneNumber: Int
  val accountType: AccountType

  def findDiff(that: Fields): Seq[String] = Seq(
    Some(contactPerson).filter(_ != that.contactPerson).map(_ => "contactPerson"),
    Some(phoneNumber).filter(_ != that.phoneNumber).map(_ => "phoneNumber"),
    Some(accountType).filter(_ != that.accountType).map(_ => "accountType")
  ).flatten
}

case class BankInfoData(accountNumber: Int,
                          bankAddress: String,
                          bankNumber: Int,
                          contactPerson: String,
                          phoneNumber: Int,
                          accountType: String) extends Fields

case class UserUpdatedFields(contactPerson: String,
                           phoneNumber: Int,
                           accountType: AccountType) extends Fields

因此可以打电话

BankInfoData(...). findDiff(UserUpdatedFields(...))

如果要进一步改进并避免多次命名所有字段,例如可以使用shapeless进行编译。不完全相同,但是类似this的入门。或使用反射来像this answer这样运行。

答案 1 :(得分:2)

每个案例类都扩展了Product接口,因此我们可以使用它来将案例类转换为(字段,值)元素集。然后,我们可以使用设置操作来找到差异。例如,

  def findDiff(profile: Profile): Seq[String] = {
    val userUpdatedFields = profile.userUpdatedFields.get
    val bankInfoData = profile.bankInfoData

    val updatedFieldsMap = userUpdatedFields.productElementNames.zip(userUpdatedFields.productIterator).toMap
    val bankInfoDataMap = bankInfoData.productElementNames.zip(bankInfoData.productIterator).toMap
    val bankInfoDataSubsetMap = bankInfoDataMap.view.filterKeys(userUpdatedFieldsMap.keys.toList.contains)
    (bankInfoDataSubsetMap.toSet diff updatedFieldsMap.toSet).toList.map { case (field, value) => field }
  }

现在findDiff(profile)应该输出List(phoneNumber, contactPerson)。请注意,我们使用的是来自Scala 2.13的productElementNames来获取文件名,然后我们将其压缩为相应的值

userUpdatedFields.productElementNames.zip(userUpdatedFields.productIterator)

我们也依靠filterKeysdiff

答案 2 :(得分:1)

如果这是将case类转换为map的简便方法,那将是一件非常容易完成的任务。不幸的是,案例类在Scala 2.12中尚未提供现成的功能(正如Mario提到的那样,在Scala 2.13中将很容易实现)。

有一个名为shapeless的库,其中提供了一些通用的编程实用程序。例如,我们可以使用toMapRecord来自无形的扩展函数ToMap

object Mappable {
  implicit class RichCaseClass[X](val x: X) extends AnyVal {
    import shapeless._
    import ops.record._

    def toMap[L <: HList](
        implicit gen: LabelledGeneric.Aux[X, L],
        toMap: ToMap[L]
    ): Map[String, Any] =
      toMap(gen.to(x)).map{
        case (k: Symbol, v) => k.name -> v
      }
    }
}

然后我们可以将其用于findDiff

def findDiff(profile: Profile): Seq[String] = {
  import Mappable._

  profile match {
    case Profile(_, _, bankInfo, Some(userUpdatedFields)) =>
      val bankInfoMap = bankInfo.toMap
      userUpdatedFields.toMap.toList.flatMap{
        case (k, v) if bankInfoMap.get(k).exists(_ != v) => Some(k)
        case _ => None
      }
    case _ => Seq()
  }
}
相关问题