1/23/08

Playing with Scala 3: OO, Traits, and Views

OO

Look at the graph classes from my previous post:

class Node (val label:String) {
  var transitions: List[Arc] = Nil
  var previous: Node = _
  var weight= Float.PositiveInfinity
  var visited = false
  ...

and

class Arc(var start: Node, var end: Node) {
  var weight: Float = _
  ...

Don't you feel that something doesn't really fit? What are previous, visited, and weight doing in Node? And weight in Arc?

Those properties don't really belong to a generic Node or Arc; they're specific to our algorithm. So why have them there? Let's put those attributes in specialized versions of Node and Arc: DfsNode and DfsArc. But wait! Putting weight in both doesn't seem right either... we'd be defining the same thing twice! Weight is a kind of "cross-cutting" attribute.

Traits

Luckily, some modern OO languages (Scala among them) have a very elegant way of solving this kind of problem: traits. They're usually described as "interfaces with implementation", but they're a more powerful concept than that. So let's define our weight as a trait:

trait Weighted {
  var weight= Float.PositiveInfinity
}
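A quick sketch of why a trait fits a cross-cutting attribute (City and Road are hypothetical classes, not part of the graph code): two unrelated classes mix in the same Weighted trait, and each instance gets its own weight field, already initialized by the trait.

```scala
trait Weighted {
  // Every class that mixes this in gets its own mutable weight,
  // starting at positive infinity ("not reached yet").
  var weight = Float.PositiveInfinity
}

// Hypothetical classes with nothing in common except being Weighted.
class City(val name: String) extends Weighted
class Road(val from: City, val to: City) extends Weighted

object TraitDemo {
  val a = new City("A")
  val b = new City("B")
  val road = new Road(a, b)
  road.weight = 2.5f // changes only the road; the cities keep their own weights

  // Code written against the trait works for both classes.
  def hasFiniteWeight(w: Weighted): Boolean = w.weight != Float.PositiveInfinity
}
```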

So we can clean our Node and Arc classes:

class Node(val label:String) {
  var transitions: List[Arc] = Nil

  override def toString()= {
    label
  }

  def --> (end: Node):Arc={
    transitions = new Arc(this,end) :: transitions
    transitions.head
  }
}

class Arc(var start: Node, var end: Node) {

  override def toString()= {
    start.label+"-->"+end.label
  }
}

Now they look useful for any graph application, and we can define our specialized versions for Dijkstra's algorithm:

class DfsNode (label:String) extends Node(label) with Weighted {
  var previous: DfsNode = _
  var visited = false

  override def toString()= {
    super.toString+" w:"+weight+" p:"+previous
  }
}

class DfsArc(start: Node, end: Node) extends Arc(start, end) with Weighted {
  override def toString()= {
    super.toString+" w:"+weight
  }
}

Then I just need to change the signatures of the algorithm's functions:

object Dijkstra {

  def shortestPath(graph:Set[DfsNode], start: DfsNode, end: DfsNode) = {
    var unvisited=graph
    start.weight=0
    while (!unvisited.isEmpty) {
      val vertx=min(unvisited)
      vertx.transitions.map(improveDistance(_))
      unvisited=unvisited-vertx
    }
  }

  def improveDistance(a:DfsArc) ={
    if (a.start.weight+a.weight< a.end.weight) {
      a.end.weight=a.start.weight+a.weight
      a.end.previous=a.start
    }
  }

  def min(nodes: Set[DfsNode]): DfsNode = {
    nodes.reduceLeft((a:DfsNode,b:DfsNode)=>if (a.weight<b.weight) a else b )
  }

  def pathTo(end:DfsNode):List[DfsNode] = {
    if (end == null)
      Nil
    else
      end :: pathTo(end.previous)
  }
}

What? It doesn't compile! vertx.transitions returns a list of Arcs, not DfsArcs, and an arc's start and end are plain Nodes, not DfsNodes! Argh! We've been bitten by static typing! ("Ha-ha!" shouts the dynamic-typing Nelson.) I could use .asInstanceOf[type] (like a Java cast), but it looks like a kludge, and you'd have to put it everywhere! There must be a better solution; after all, Scala is supposed to be elegant, right?
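For comparison, here's roughly what the cast-everywhere approach could look like for improveDistance. It's a sketch with the classes condensed from this post (DfsArc is assumed to be Arc with Weighted mixed in). The casts compile, but if an arc or node isn't really a DfsArc/DfsNode, they fail at runtime with a ClassCastException.

```scala
trait Weighted {
  var weight = Float.PositiveInfinity
}

class Node(val label: String) {
  var transitions: List[Arc] = Nil
}
class Arc(var start: Node, var end: Node)

class DfsNode(label: String) extends Node(label) with Weighted {
  var previous: DfsNode = null
}
class DfsArc(start: Node, end: Node) extends Arc(start, end) with Weighted

object CastingDijkstra {
  // Arc only promises plain Nodes, so every DfsNode/DfsArc member access
  // needs an explicit downcast. A wrong cast is a runtime error
  // (ClassCastException), not a compile-time one.
  def improveDistance(arc: Arc): Unit = {
    val a   = arc.asInstanceOf[DfsArc]
    val src = a.start.asInstanceOf[DfsNode]
    val dst = a.end.asInstanceOf[DfsNode]
    if (src.weight + a.weight < dst.weight) {
      dst.weight = src.weight + a.weight
      dst.previous = src
    }
  }
}
```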

Views (implicit conversions)

Scala has a very interesting way of dealing with that kind of problem: views, created with implicit conversions. An implicit conversion converts one type to another, and the compiler inserts the call wherever it's needed. So I defined conversions from Arc to DfsArc and from Node to DfsNode:

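object Implicits {
  // Bring these views into scope where they're needed with: import Implicits._
  implicit def node2DfsNode(node:Node):DfsNode={
    if (node.isInstanceOf[DfsNode])
      node.asInstanceOf[DfsNode]
    else {
      // A plain Node: build a new DfsNode that shares its transitions
      val dfsNode=new DfsNode(node.label)
      dfsNode.transitions=node.transitions
      dfsNode
    }
  }

  implicit def arc2DfsArc(arc:Arc):DfsArc={
    if (arc.isInstanceOf[DfsArc])
      arc.asInstanceOf[DfsArc]
    else {
      // A plain Arc: build a new DfsArc (a different object from the original)
      new DfsArc(arc.start,arc.end)
    }
  }
}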
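To see what the compiler does with a view, here's a stripped-down sketch (the classes keep only what the example needs, and visit() is a hypothetical helper). I swapped the isInstanceOf/asInstanceOf pair for a pattern match, which performs the same type test. Note the two branches: a Node that already is a DfsNode comes back as the same object, while a plain Node becomes a fresh copy, so changes made through the view are lost. Newer Scala versions also want the scala.language.implicitConversions import, otherwise they emit a feature warning.

```scala
import scala.language.implicitConversions

// Hypothetical, stripped-down stand-ins for the post's Node and DfsNode.
class Node(val label: String)

class DfsNode(label: String) extends Node(label) {
  var visited = false
  def visit(): Unit = { visited = true }
}

object Implicits {
  // The view: applied by the compiler wherever a Node is used as a DfsNode.
  implicit def node2DfsNode(node: Node): DfsNode = node match {
    case d: DfsNode => d                        // already a DfsNode: same object
    case other      => new DfsNode(other.label) // plain Node: a fresh copy
  }
}

object ViewDemo {
  import Implicits._

  // Statically a Node, but really a DfsNode at runtime.
  val shared: Node = new DfsNode("shared")
  // A genuinely plain Node.
  val plain: Node = new Node("plain")

  // Node has no visit() method, so the compiler rewrites these calls
  // to node2DfsNode(shared).visit() and node2DfsNode(plain).visit().
  shared.visit()
  plain.visit()

  // The flag sticks on `shared`: the view returned the same object...
  val sharedVisited: Boolean = shared.visited
  // ...but not on `plain`: each conversion builds a brand-new DfsNode.
  val plainVisited: Boolean = plain.visited
}
```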

And now everything happily compiles. The only caveat: since the --> method returns a plain Arc, and arc2DfsArc creates a new instance when it's given a plain Arc, writing n1-->n2 weight=2 sets the weight on that new object, not on the arc stored in the transitions list. I'm sure there's a better solution, but meanwhile, overloading the --> method in DfsNode (same name, more specific parameter type) does the trick:
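// An overload of --> (its parameter type is more specific than Node's version),
// so DfsNode-->DfsNode returns the very DfsArc stored in transitions, no cast needed
def --> (end: DfsNode):DfsArc={
  val arc = new DfsArc(this,end)
  transitions = arc :: transitions
  arc
}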
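Here's a minimal, self-contained sketch of that overload (classes condensed from this post, with DfsArc assumed to be Arc with Weighted). Because n1 and n2 are statically DfsNodes, n1 --> n2 resolves to the more specific overload, so the weight lands on the DfsArc that is actually stored in transitions, with no view and no copy involved.

```scala
trait Weighted {
  var weight = Float.PositiveInfinity
}

class Node(val label: String) {
  var transitions: List[Arc] = Nil

  def -->(end: Node): Arc = {
    val arc = new Arc(this, end)
    transitions = arc :: transitions
    arc
  }
}

class Arc(var start: Node, var end: Node)

class DfsArc(start: Node, end: Node) extends Arc(start, end) with Weighted

class DfsNode(label: String) extends Node(label) with Weighted {
  // An overload, not an override: the parameter type is more specific, so
  // dfsNode --> dfsNode picks this version and returns the stored DfsArc.
  def -->(end: DfsNode): DfsArc = {
    val arc = new DfsArc(this, end)
    transitions = arc :: transitions
    arc
  }
}

object OverloadDemo {
  val n1 = new DfsNode("n1")
  val n2 = new DfsNode("n2")

  // This only type-checks because the DfsNode overload was chosen.
  val arc: DfsArc = n1 --> n2
  arc.weight = 2f

  // Read the weight back from the arc actually stored in the node.
  val storedWeight: Float = n1.transitions.head match {
    case d: DfsArc => d.weight
    case _         => Float.NaN
  }
}
```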

(Let me know if you want the full code posted)

1/15/08

Playing with SCALA 2: Dijkstra’s shortest path (Dsp) algorithm

As I continue to play with Scala, I wanted to do something non-trivial, so I decided to implement Dijkstra’s shortest path algorithm.
As background, my experience is mostly in object-oriented analysis and design (OOAD) and programming in Java (I use it at work).
The first challenge was getting the algorithm right (my first attempt ended up as a beautiful recursive algorithm that just traversed the graph, greedily selecting the shortest arc), but Wikipedia’s page helped.
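To make the greedy pitfall concrete, here's a sketch of that kind of greedy walk (my reconstruction, not the original code) on the test graph from the end of this post. Always taking the cheapest arc goes n1 → n3 → n4 → n6 for a cost of 1 + 3 + 1 = 5, while the shortest path n1 → n2 → n4 → n6 costs 2 + 1 + 1 = 4. The adjacency-map representation and the function names are assumptions made for the example.

```scala
object GreedyVsShortest {
  // The test graph from the end of this post: node -> (neighbor, arc weight).
  val graph: Map[String, List[(String, Double)]] = Map(
    "n1" -> List("n2" -> 2.0, "n3" -> 1.0),
    "n2" -> List("n4" -> 1.0, "n5" -> 1.0),
    "n3" -> List("n4" -> 3.0),
    "n4" -> List("n6" -> 1.0),
    "n5" -> List("n6" -> 3.0),
    "n6" -> Nil
  )

  // Greedy walk: always follow the cheapest outgoing arc to an unseen node.
  // Returns the total cost if it reaches `target`, or None if it gets stuck.
  def greedyCost(from: String, target: String): Option[Double] = {
    @annotation.tailrec
    def loop(node: String, acc: Double, seen: Set[String]): Option[Double] =
      if (node == target) Some(acc)
      else {
        val options = graph(node).filterNot { case (next, _) => seen(next) }
        if (options.isEmpty) None
        else {
          val (next, w) = options.minBy(_._2)
          loop(next, acc + w, seen + next)
        }
      }
    loop(from, 0.0, Set(from))
  }

  // The true minimum over all paths (fine for a tiny acyclic graph like this one).
  def bruteForceCost(from: String, target: String): Double =
    if (from == target) 0.0
    else {
      val viaEachArc = graph(from).map { case (next, w) => w + bruteForceCost(next, target) }
      viaEachArc.foldLeft(Double.PositiveInfinity)((a, b) => math.min(a, b))
    }
}
```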
Just in case, be aware that I’m no Scala expert and no functional programming expert (I’m no expert on anything at all). I’m just trying to share what I found; I’m sure there’s a better way, so feel free to suggest it.

So far, I’ve enjoyed the experience: programming in Scala is easy and fun (at least compared to Java or C++… yes, yes, that doesn’t take much; I can think of many jokes, leave yours in the comments ;-)).
Sure, I spent a while on seemingly trivial stuff, but that happens when you learn any language (or framework, or anything new!).
For instance, I tried to define a property for a class; in Scala you can do it right in the declaration, like this:
class Node (label:String) { … }
But I couldn’t do node.label!! Then I realized I was declaring a private field: if you want accessors you need to use val (and you get a getter) or var (and you get a getter and a setter), so I changed it to:
class Node (val label:String) { … }
And it worked :-D
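The three flavors side by side, as a small sketch (the class names are hypothetical): a plain constructor parameter is only usable inside the class, val adds a public getter, and var adds a getter and a setter.

```scala
// Plain parameter: usable inside the class body, but no public accessor.
class PlainNode(label: String) {
  def describe: String = "node " + label
}
class ValNode(val label: String) // public getter: node.label
class VarNode(var label: String) // public getter and setter: node.label = "x"

object AccessorDemo {
  val plain = new PlainNode("a")
  // plain.label               // would not compile: no accessor is generated
  val v = new ValNode("b")
  val w = new VarNode("c")
  w.label = "d"                // the setter generated by `var`
}
```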
I did a quick search but couldn’t find a quick and easy functional Dsp implementation (it seems there isn’t one, which makes sense, since the algorithm relies on shared mutable state).
One of the main steps of the algorithm is finding the unvisited node with the lowest cost. I came up with:
nodes.reduceLeft((a:Node, b:Node)=> if (a.weight<b.weight) a else b)
Maybe there’s a better way, but I learned how to use “reduce”, and it looks quite concise and clear :-)
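For what it’s worth, newer Scala versions also have minBy on collections, which expresses the same selection directly. A small sketch comparing the two (the stripped-down Node is hypothetical). Both throw an exception on an empty set, so the caller still has to make sure there’s at least one unvisited node.

```scala
// Hypothetical minimal node with just the field the comparison needs.
class Node(val label: String) {
  var weight = Float.PositiveInfinity
}

object MinDemo {
  // The post's approach: fold the set pairwise, keeping the lighter node.
  def minByReduce(nodes: Set[Node]): Node =
    nodes.reduceLeft((a: Node, b: Node) => if (a.weight < b.weight) a else b)

  // The same selection using the standard library's minBy.
  def minByBuiltin(nodes: Set[Node]): Node =
    nodes.minBy(_.weight)
}
```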
Then, you need to check if you can improve the cost of the adjacent nodes using the one you just found, so I defined:

def improveDistance(a:Arc) ={
  if (a.start.weight+a.weight< a.end.weight) {
    a.end.weight=a.start.weight+a.weight
    a.end.previous=a.start
  }
}

Nothing magic, but then I can use “map” to apply it to all the arcs from the node:
  vertx.transitions.map(improveDistance(_))
Don’t be scared by the “_”: I’m just lazy, and glad that Scala lets me avoid typing “map((a:Arc)=>improveDistance(a))”.
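The three spellings below are equivalent; improve is a hypothetical stand-in for improveDistance that just sums costs so the effect is visible. The sketch uses foreach rather than map: since improveDistance is only called for its side effect, map would build a list of Unit values that nobody uses, while foreach says exactly what we mean.

```scala
class Arc(val cost: Int)

object PlaceholderDemo {
  var total = 0
  // Hypothetical stand-in for improveDistance: a side effect only.
  def improve(a: Arc): Unit = { total += a.cost }

  val arcs = List(new Arc(1), new Arc(2), new Arc(3))

  // All three do the same thing: `improve(_)` is shorthand for
  // `(a: Arc) => improve(a)`, and a bare method name is eta-expanded.
  def runLong(): Unit  = arcs.foreach((a: Arc) => improve(a))
  def runShort(): Unit = arcs.foreach(improve(_))
  def runBare(): Unit  = arcs.foreach(improve)
}
```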
I didn’t use any of Scala’s unit testing frameworks, but I wrote my own sort of unit test, so I declared the arcs:
var a12= new Arc(n1,n2,1.0)
var a13= new Arc(n1,n3,2.2)
var a24= new Arc(n2,n4,1.5)
But it felt too “Java”. I wanted to try something more “DSL-ish”, and Scala can help because symbolic names like --> are valid method names, so I added the following to the Node class (transitions is the node’s list of arcs):
def --> (end: Node):Arc={
  transitions = new Arc(this,end) :: transitions
  transitions.head
}

Why? Because now, to add an arc of weight 2 from node N1 to node N2, I can write:
N1-->N2 weight=2
Isn’t that neat? :-)
(I used --> instead of -> because I didn’t want to freak out people who are wary of changing the meaning of operators, and because Scala already uses -> for maps.)
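Under the hood, --> is just a method, and N1-->N2 weight=2 amounts to (N1 --> N2).weight = 2: call --> on N1, then assign the weight of the Arc it returns. Since --> returns the same Arc it stored in transitions, the weight sticks. A small sketch using the explicit forms (newer Scala versions treat postfix syntax like the bare weight as a language feature, scala.language.postfixOps); giving weight a 0f default instead of `_` is just a simplification for the example.

```scala
class Node(val label: String) {
  var transitions: List[Arc] = Nil

  // A symbolic name is an ordinary method name in Scala.
  def -->(end: Node): Arc = {
    val arc = new Arc(this, end)
    transitions = arc :: transitions
    arc
  }
}

class Arc(val start: Node, val end: Node) {
  var weight: Float = 0f
}

object DslDemo {
  val n1 = new Node("N1")
  val n2 = new Node("N2")

  // Operator notation: n1 --> n2 means n1.-->(n2); the returned Arc is the
  // same object stored in n1.transitions, so setting its weight sticks.
  (n1 --> n2).weight = 2f

  // The fully explicit method-call form of the same idea.
  n1.-->(new Node("N3")).weight = 5f
}
```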
Nodes are declared as var n1=new Node("Start"). It would be nicer to declare them as var n2= Node "Node2", but it didn’t bother me that much, and I couldn’t find a quicker way other than using case classes. If somebody knows a better way, let me know!
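One way to get close to that without case classes: give Node a companion object with an apply method, which lets you write Node("Node2") (the parentheses are still needed). A minimal sketch; Scala 3 goes further and lets you omit new for any class.

```scala
class Node(val label: String)

// Companion object: its apply method lets callers write Node("Node2")
// instead of new Node("Node2"), without turning Node into a case class.
object Node {
  def apply(label: String): Node = new Node(label)
}

object FactoryDemo {
  val n2 = Node("Node2") // calls Node.apply("Node2")
}
```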
I leave you here with the full code. Please suggest improvements!
package myScala;

class Node (val label:String) {
  var transitions: List[Arc] = Nil
  var previous: Node = _
  var weight= Float.PositiveInfinity
  var visited = false

  override def toString()= {
    label+" w:"+weight+" p:"+previous
  }

  def --> (end: Node):Arc={
    transitions = new Arc(this,end) :: transitions
    transitions.head
  }
}

class Arc(var start: Node, var end: Node) {
  var weight: Float = _

  override def toString()= {
    start.label+"-->"+end.label+" w:"+weight
  }
}
package myScala;

object Dijkstra {
  def shortestPath(graph:Set[Node], start: Node, end: Node) = {
    var unvisited=graph
    start.weight=0
    while (!unvisited.isEmpty) {
      val vertx=min(unvisited)
      vertx.transitions.map(improveDistance(_))
      unvisited=unvisited-vertx
    }
  }

  def improveDistance(a:Arc) ={
    if (a.start.weight+a.weight< a.end.weight) {
      a.end.weight=a.start.weight+a.weight
      a.end.previous=a.start
    }
  }

  def min(nodes: Set[Node]): Node = {
    nodes.reduceLeft((a:Node, b:Node) => if (a.weight<b.weight) a else b)
  }

  def pathTo(end:Node):List[Node] = {
    if (end == null)
      Nil
    else
      end :: pathTo(end.previous)
  }
}
package myScala;

object Test {
  /*
    n1 --2--> n2 --1--> n5
    |         |         |
    1         1         3
    |         |         |
    V         V         V
    n3 --3--> n4 --1--> n6
  */
  var n1=new Node("Start")
  var n2=new Node("Node2")
  var n3=new Node("Node3")
  var n4=new Node("Node4")
  var n5=new Node("Node5")
  var n6=new Node("End")

  n1-->n2 weight=2
  n1-->n3 weight=1
  n2-->n4 weight=1
  n3-->n4 weight=3
  n2-->n5 weight=1
  n4-->n6 weight=1
  n5-->n6 weight=3

  var graph= Set(n1, n2, n3, n4, n5, n6)

  def main(args: Array[String]): Unit = {
    Dijkstra.shortestPath(graph,n1,n6)
    println("Path")
    Dijkstra.pathTo(n6).reverse.map(
      (v:Node)=>println(v.label+" dist:"+v.weight)
    )
  }
}